Add files using upload-large-folder tool
(This view is limited to 50 files because the commit contains too many changes; the full change set is only visible in the raw diff.)
- .gitattributes +2 -0
- docs/transformers/tests/fixtures/spiece.model +3 -0
- docs/transformers/tests/fixtures/test_sentencepiece.model +3 -0
- docs/transformers/tests/fixtures/test_sentencepiece_bpe.model +3 -0
- docs/transformers/tests/fixtures/test_sentencepiece_bpe_char.model +3 -0
- docs/transformers/tests/fixtures/test_sentencepiece_no_bos.model +3 -0
- docs/transformers/tests/fixtures/test_sentencepiece_with_bytefallback.model +3 -0
- docs/transformers/tests/fixtures/tests_samples/COCO/000000004016.png +3 -0
- docs/transformers/tests/fixtures/tests_samples/COCO/000000039769.png +3 -0
- docs/transformers/tests/models/byt5/__init__.py +0 -0
- docs/transformers/tests/models/byt5/test_tokenization_byt5.py +366 -0
- docs/transformers/tests/models/camembert/__init__.py +0 -0
- docs/transformers/tests/models/camembert/test_modeling_tf_camembert.py +55 -0
- docs/transformers/tests/models/camembert/test_tokenization_camembert.py +220 -0
- docs/transformers/tests/models/canine/__init__.py +0 -0
- docs/transformers/tests/models/canine/test_modeling_canine.py +571 -0
- docs/transformers/tests/models/canine/test_tokenization_canine.py +339 -0
- docs/transformers/tests/models/chameleon/__init__.py +0 -0
- docs/transformers/tests/models/chameleon/test_image_processing_chameleon.py +204 -0
- docs/transformers/tests/models/chameleon/test_modeling_chameleon.py +481 -0
- docs/transformers/tests/models/chameleon/test_processor_chameleon.py +76 -0
- docs/transformers/tests/models/chinese_clip/__init__.py +0 -0
- docs/transformers/tests/models/chinese_clip/test_image_processing_chinese_clip.py +175 -0
- docs/transformers/tests/models/chinese_clip/test_modeling_chinese_clip.py +762 -0
- docs/transformers/tests/models/chinese_clip/test_processor_chinese_clip.py +217 -0
- docs/transformers/tests/models/clap/__init__.py +0 -0
- docs/transformers/tests/models/clap/test_feature_extraction_clap.py +546 -0
- docs/transformers/tests/models/clap/test_modeling_clap.py +755 -0
- docs/transformers/tests/models/clap/test_processor_clap.py +125 -0
- docs/transformers/tests/models/clip/__init__.py +0 -0
- docs/transformers/tests/models/clip/test_image_processing_clip.py +128 -0
- docs/transformers/tests/models/clip/test_modeling_clip.py +948 -0
- docs/transformers/tests/models/clip/test_modeling_flax_clip.py +468 -0
- docs/transformers/tests/models/clip/test_modeling_tf_clip.py +662 -0
- docs/transformers/tests/models/clip/test_processor_clip.py +199 -0
- docs/transformers/tests/models/clip/test_tokenization_clip.py +192 -0
- docs/transformers/tests/models/clipseg/__init__.py +0 -0
- docs/transformers/tests/models/clipseg/test_modeling_clipseg.py +714 -0
- docs/transformers/tests/models/clipseg/test_processor_clipseg.py +194 -0
- docs/transformers/tests/models/clvp/__init__.py +0 -0
- docs/transformers/tests/models/clvp/test_feature_extraction_clvp.py +240 -0
- docs/transformers/tests/models/clvp/test_modeling_clvp.py +640 -0
- docs/transformers/tests/models/clvp/test_processor_clvp.py +136 -0
- docs/transformers/tests/models/clvp/test_tokenization_clvp.py +317 -0
- docs/transformers/tests/models/code_llama/__init__.py +0 -0
- docs/transformers/tests/models/code_llama/test_tokenization_code_llama.py +653 -0
- docs/transformers/tests/models/codegen/__init__.py +0 -0
- docs/transformers/tests/models/codegen/test_modeling_codegen.py +492 -0
- docs/transformers/tests/models/codegen/test_tokenization_codegen.py +329 -0
- docs/transformers/tests/models/cohere/__init__.py +0 -0
.gitattributes CHANGED
@@ -36,3 +36,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 asset/banner.png filter=lfs diff=lfs merge=lfs -text
 docs/resources/web-ui.jpg filter=lfs diff=lfs merge=lfs -text
 docs/resources/dpo_data.png filter=lfs diff=lfs merge=lfs -text
+docs/transformers/tests/fixtures/tests_samples/COCO/000000039769.png filter=lfs diff=lfs merge=lfs -text
+docs/transformers/tests/fixtures/tests_samples/COCO/000000004016.png filter=lfs diff=lfs merge=lfs -text
docs/transformers/tests/fixtures/spiece.model ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fefb02b667a6c5c2fe27602d28e5fb3428f66ab89c7d6f388e7c8d44a02d0336
+size 760289
docs/transformers/tests/fixtures/test_sentencepiece.model ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8dfd1eae4522281b1b839eab877a791befec7a1663a41c814c77d9c89c748f2d
+size 253154
docs/transformers/tests/fixtures/test_sentencepiece_bpe.model ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c4de78f5d11ee09141165d31da7dad97e809dd6ee7b52a0cbc6d76a973028286
+size 251527
docs/transformers/tests/fixtures/test_sentencepiece_bpe_char.model ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7fcc48f3e225f627b1641db410ceb0c8649bd2b0c982e150b03f8be3728ab560
+size 238473
docs/transformers/tests/fixtures/test_sentencepiece_no_bos.model ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6f3af97c2e7bc51d781e7440aa33deee7f482eac819d23fd24af80e7b4ce2646
+size 253134
docs/transformers/tests/fixtures/test_sentencepiece_with_bytefallback.model ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c61ecce43369fc3bab9566464f0e71f3ad75dc2319a5aadc2a561e3e312502e3
+size 270096
docs/transformers/tests/fixtures/tests_samples/COCO/000000004016.png ADDED (Git LFS)

docs/transformers/tests/fixtures/tests_samples/COCO/000000039769.png ADDED (Git LFS)
docs/transformers/tests/models/byt5/__init__.py ADDED (empty file)
docs/transformers/tests/models/byt5/test_tokenization_byt5.py ADDED
@@ -0,0 +1,366 @@
# Copyright 2020 Google T5 Authors and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import re
import shutil
import tempfile
import unittest
from functools import lru_cache

from transformers import AddedToken, BatchEncoding, ByT5Tokenizer
from transformers.utils import cached_property, is_tf_available, is_torch_available

from ...test_tokenization_common import TokenizerTesterMixin, use_cache_if_possible


if is_torch_available():
    FRAMEWORK = "pt"
elif is_tf_available():
    FRAMEWORK = "tf"
else:
    FRAMEWORK = "jax"


class ByT5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
    tokenizer_class = ByT5Tokenizer
    test_rust_tokenizer = False

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        tokenizer = ByT5Tokenizer()
        tokenizer.save_pretrained(cls.tmpdirname)

    @cached_property
    def t5_base_tokenizer(self):
        return ByT5Tokenizer.from_pretrained("google/byt5-small")

    @classmethod
    @use_cache_if_possible
    @lru_cache(maxsize=64)
    def get_tokenizer(cls, pretrained_name=None, **kwargs) -> ByT5Tokenizer:
        pretrained_name = pretrained_name or cls.tmpdirname
        return cls.tokenizer_class.from_pretrained(pretrained_name, **kwargs)

    def get_clean_sequence(self, tokenizer, with_prefix_space=False, max_length=20, min_length=5) -> tuple[str, list]:
        # XXX The default common tokenizer tests assume that every ID is decodable on its own.
        # This assumption is invalid for ByT5 because single bytes might not be
        # valid utf-8 (byte 128 for instance).
        # Here we're overriding the smallest possible method to provide
        # a clean sequence without making the same assumption.

        toks = []
        for i in range(len(tokenizer)):
            try:
                tok = tokenizer.decode([i], clean_up_tokenization_spaces=False)
            except UnicodeDecodeError:
                pass
            toks.append((i, tok))

        toks = list(filter(lambda t: re.match(r"^[ a-zA-Z]+$", t[1]), toks))
        toks = list(filter(lambda t: [t[0]] == tokenizer.encode(t[1], add_special_tokens=False), toks))
        if max_length is not None and len(toks) > max_length:
            toks = toks[:max_length]
        if min_length is not None and len(toks) < min_length and len(toks) > 0:
            while len(toks) < min_length:
                toks = toks + toks
        # toks_str = [t[1] for t in toks]
        toks_ids = [t[0] for t in toks]

        # Ensure consistency
        output_txt = tokenizer.decode(toks_ids, clean_up_tokenization_spaces=False)
        if " " not in output_txt and len(toks_ids) > 1:
            output_txt = (
                tokenizer.decode([toks_ids[0]], clean_up_tokenization_spaces=False)
                + " "
                + tokenizer.decode(toks_ids[1:], clean_up_tokenization_spaces=False)
            )
        if with_prefix_space:
            output_txt = " " + output_txt
        output_ids = tokenizer.encode(output_txt, add_special_tokens=False)
        return output_txt, output_ids

    def test_eos_treatment(self):
        tokenizer = self.t5_base_tokenizer
        batch_with_eos_added = tokenizer(["hi</s>", "I went to the gym</s>", "</s>"])
        batch_without_eos_added = tokenizer(["hi", "I went to the gym", ""])
        self.assertListEqual(batch_with_eos_added["input_ids"], batch_without_eos_added["input_ids"])

    def test_multibytes_char(self):
        tokenizer = self.t5_base_tokenizer
        src_text = "Unicode €."
        encoded = tokenizer(src_text)
        encoded_ids = [88, 113, 108, 102, 114, 103, 104, 35, 229, 133, 175, 49, 1]
        self.assertEqual(encoded["input_ids"], encoded_ids)

        # decoding
        decoded = tokenizer.decode(encoded_ids)
        self.assertEqual(decoded, "Unicode €.</s>")

        encoded = tokenizer("e è é ê ë")
        encoded_ids = [104, 35, 198, 171, 35, 198, 172, 35, 198, 173, 35, 198, 174, 1]
        self.assertEqual(encoded["input_ids"], encoded_ids)
        # decoding
        decoded = tokenizer.decode(encoded_ids)
        self.assertEqual(decoded, "e è é ê ë</s>")

        # encode/decode, but with `encode` instead of `__call__`
        self.assertEqual(tokenizer.decode(tokenizer.encode("e è é ê ë")), "e è é ê ë</s>")

    def test_prepare_batch_integration(self):
        tokenizer = self.t5_base_tokenizer
        src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."]
        expected_src_tokens = [68, 35, 111, 114, 113, 106, 35, 115, 100, 117, 100, 106, 117, 100, 115, 107, 35, 105, 114, 117, 35, 118, 120, 112, 112, 100, 117, 108, 125, 100, 119, 108, 114, 113, 49, 1, 0]  # fmt: skip
        batch = tokenizer(src_text, padding=True, return_tensors=FRAMEWORK)
        self.assertIsInstance(batch, BatchEncoding)

        if FRAMEWORK != "jax":
            result = list(batch.input_ids.numpy()[0])
        else:
            result = list(batch.input_ids.tolist()[0])

        self.assertListEqual(expected_src_tokens, result)

        self.assertEqual((2, 37), batch.input_ids.shape)
        self.assertEqual((2, 37), batch.attention_mask.shape)

    def test_empty_target_text(self):
        tokenizer = self.t5_base_tokenizer
        src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."]
        batch = tokenizer(src_text, padding=True, return_tensors=FRAMEWORK)
        # check if input_ids are returned and no decoder_input_ids
        self.assertIn("input_ids", batch)
        self.assertIn("attention_mask", batch)
        self.assertNotIn("decoder_input_ids", batch)
        self.assertNotIn("decoder_attention_mask", batch)

    def test_max_length_integration(self):
        tokenizer = self.t5_base_tokenizer
        tgt_text = [
            "Summary of the text.",
            "Another summary.",
        ]
        targets = tokenizer(
            text_target=tgt_text, max_length=32, padding="max_length", truncation=True, return_tensors=FRAMEWORK
        )
        self.assertEqual(32, targets["input_ids"].shape[1])

    def test_eos_in_input(self):
        tokenizer = self.t5_base_tokenizer
        src_text = ["A long paragraph for summarization. </s>"]
        tgt_text = ["Summary of the text. </s>"]
        expected_src_tokens = [68, 35, 111, 114, 113, 106, 35, 115, 100, 117, 100, 106, 117, 100, 115, 107, 35, 105, 114, 117, 35, 118, 120, 112, 112, 100, 117, 108, 125, 100, 119, 108, 114, 113, 49, 35, 1]  # fmt: skip
        expected_tgt_tokens = [86, 120, 112, 112, 100, 117, 124, 35, 114, 105, 35, 119, 107, 104, 35, 119, 104, 123, 119, 49, 35, 1]  # fmt: skip

        batch = tokenizer(src_text, text_target=tgt_text)

        self.assertEqual(expected_src_tokens, batch["input_ids"][0])
        self.assertEqual(expected_tgt_tokens, batch["labels"][0])

    # cannot use default save_and_load_tokenizer test method because tokenizer has no vocab
    def test_save_and_load_tokenizer(self):
        # safety check on max_len default value so we are sure the test works
        tokenizers = self.get_tokenizers()
        for tokenizer in tokenizers:
            with self.subTest(f"{tokenizer.__class__.__name__}"):
                self.assertNotEqual(tokenizer.model_max_length, 42)

        # Now let's start the test
        tokenizers = self.get_tokenizers()
        for tokenizer in tokenizers:
            with self.subTest(f"{tokenizer.__class__.__name__}"):
                # Isolate this from the other tests because we save additional tokens/etc
                tmpdirname = tempfile.mkdtemp()

                sample_text = " He is very happy, UNwant\u00e9d,running"
                before_tokens = tokenizer.encode(sample_text, add_special_tokens=False)
                tokenizer.save_pretrained(tmpdirname)

                after_tokenizer = tokenizer.__class__.from_pretrained(tmpdirname)
                after_tokens = after_tokenizer.encode(sample_text, add_special_tokens=False)
                self.assertListEqual(before_tokens, after_tokens)

                shutil.rmtree(tmpdirname)

        tokenizers = self.get_tokenizers(model_max_length=42)
        for tokenizer in tokenizers:
            with self.subTest(f"{tokenizer.__class__.__name__}"):
                # Isolate this from the other tests because we save additional tokens/etc
                tmpdirname = tempfile.mkdtemp()

                sample_text = " He is very happy, UNwant\u00e9d,running"
                tokenizer.add_tokens(["bim", "bambam"])
                additional_special_tokens = tokenizer.additional_special_tokens
                additional_special_tokens.append("new_additional_special_token")
                tokenizer.add_special_tokens(
                    {"additional_special_tokens": additional_special_tokens}, replace_additional_special_tokens=False
                )
                before_tokens = tokenizer.encode(sample_text, add_special_tokens=False)
                tokenizer.save_pretrained(tmpdirname)

                after_tokenizer = tokenizer.__class__.from_pretrained(tmpdirname)
                after_tokens = after_tokenizer.encode(sample_text, add_special_tokens=False)
                self.assertListEqual(before_tokens, after_tokens)
                self.assertIn("new_additional_special_token", after_tokenizer.additional_special_tokens)
                self.assertEqual(after_tokenizer.model_max_length, 42)

                tokenizer = tokenizer.__class__.from_pretrained(tmpdirname, model_max_length=43)
                self.assertEqual(tokenizer.model_max_length, 43)

                shutil.rmtree(tmpdirname)

    # There is a conflict between the default value of extra_ids and adding a new special token through additional_special_tokens
    # We need to add the extra_ids in the list of the arg additional_special_tokens
    def test_special_tokens_initialization_with_non_empty_additional_special_tokens(self):
        tokenizer_list = []
        if self.test_slow_tokenizer:
            tokenizer_list.append((self.tokenizer_class, self.get_tokenizer()))

        if self.test_rust_tokenizer:
            tokenizer_list.append((self.rust_tokenizer_class, self.get_rust_tokenizer()))

        for tokenizer_class, tokenizer_utils in tokenizer_list:
            with tempfile.TemporaryDirectory() as tmp_dir:
                tokenizer_utils.save_pretrained(tmp_dir)

                with open(os.path.join(tmp_dir, "special_tokens_map.json"), encoding="utf-8") as json_file:
                    special_tokens_map = json.load(json_file)

                with open(os.path.join(tmp_dir, "tokenizer_config.json"), encoding="utf-8") as json_file:
                    tokenizer_config = json.load(json_file)

                added_tokens_extra_ids = [f"<extra_id_{i}>" for i in range(125)]

                special_tokens_map["additional_special_tokens"] = added_tokens_extra_ids + [
                    "an_additional_special_token"
                ]
                tokenizer_config["additional_special_tokens"] = added_tokens_extra_ids + [
                    "an_additional_special_token"
                ]

                with open(os.path.join(tmp_dir, "special_tokens_map.json"), "w", encoding="utf-8") as outfile:
                    json.dump(special_tokens_map, outfile)
                with open(os.path.join(tmp_dir, "tokenizer_config.json"), "w", encoding="utf-8") as outfile:
                    json.dump(tokenizer_config, outfile)

                # the following checks allow us to verify that our test works as expected, i.e. that the tokenizer takes
                # into account the new value of additional_special_tokens given in the "tokenizer_config.json" and
                # "special_tokens_map.json" files
                tokenizer_without_change_in_init = tokenizer_class.from_pretrained(
                    tmp_dir,
                )
                self.assertIn(
                    "an_additional_special_token", tokenizer_without_change_in_init.additional_special_tokens
                )
                # self.assertIn("an_additional_special_token",tokenizer_without_change_in_init.get_vocab()) # ByT5Tokenization no vocab
                self.assertEqual(
                    ["an_additional_special_token"],
                    tokenizer_without_change_in_init.convert_ids_to_tokens(
                        tokenizer_without_change_in_init.convert_tokens_to_ids(["an_additional_special_token"])
                    ),
                )

                # Now we test that we can change the value of additional_special_tokens in the from_pretrained
                new_added_tokens = added_tokens_extra_ids + [AddedToken("a_new_additional_special_token", lstrip=True)]
                tokenizer = tokenizer_class.from_pretrained(
                    tmp_dir,
                    additional_special_tokens=new_added_tokens,
                )

                self.assertIn("a_new_additional_special_token", tokenizer.additional_special_tokens)
                self.assertEqual(
                    ["a_new_additional_special_token"],
                    tokenizer.convert_ids_to_tokens(
                        tokenizer.convert_tokens_to_ids(["a_new_additional_special_token"])
                    ),
                )

    def test_decode_single_bytes(self):
        tokenizer_list = []
        if self.test_slow_tokenizer:
            tokenizer_list.append((self.tokenizer_class, self.get_tokenizer()))

        if self.test_rust_tokenizer:
            tokenizer_list.append((self.rust_tokenizer_class, self.get_rust_tokenizer()))

        for tokenizer_class, tokenizer_utils in tokenizer_list:
            with tempfile.TemporaryDirectory() as tmp_dir:
                tokenizer_utils.save_pretrained(tmp_dir)

                tokenizer = tokenizer_class.from_pretrained(tmp_dir)

                self.assertTrue(tokenizer.decode([255]) == "")

    @unittest.skip(reason="ByT5Tokenizer does not have a vocabulary")
    def test_get_vocab(self):
        pass

    @unittest.skip(reason="inputs cannot be pretokenized as ids depend on whole input string")
    def test_pretokenized_inputs(self):
        pass

    @unittest.skip(reason="ByT5Tokenizer does not have a vocabulary")
    def test_conversion_reversible(self):
        pass

    def test_convert_tokens_to_string_format(self):
        # The default common tokenizer tests uses invalid tokens for ByT5 that can only accept one-character strings
        # and special added tokens as tokens
        tokenizers = self.get_tokenizers(fast=True, do_lower_case=True)
        for tokenizer in tokenizers:
            with self.subTest(f"{tokenizer.__class__.__name__}"):
                tokens = ["t", "h", "i", "s", " ", "i", "s", " ", "a", " ", "t", "e", "x", "t", "</s>"]
                string = tokenizer.convert_tokens_to_string(tokens)

                self.assertIsInstance(string, str)

    # We need a different implementation of the test of the same name defined in TokenizerTesterMixin because this tokenizer
    # doesn't have a vocab
    def test_tokenizers_common_ids_setters(self):
        tokenizers = self.get_tokenizers()
        for tokenizer in tokenizers:
            with self.subTest(f"{tokenizer.__class__.__name__}"):
                attributes_list = [
                    "bos_token",
                    "eos_token",
                    "unk_token",
                    "sep_token",
                    "pad_token",
                    "cls_token",
                    "mask_token",
                ]

                token_id_to_test_setters = 0
                token_to_test_setters = tokenizer.convert_ids_to_tokens(
                    token_id_to_test_setters, skip_special_tokens=False
                )

                for attr in attributes_list:
                    setattr(tokenizer, attr + "_id", None)
                    self.assertEqual(getattr(tokenizer, attr), None)
                    self.assertEqual(getattr(tokenizer, attr + "_id"), None)

                    setattr(tokenizer, attr + "_id", token_id_to_test_setters)
                    self.assertEqual(getattr(tokenizer, attr), token_to_test_setters)
                    self.assertEqual(getattr(tokenizer, attr + "_id"), token_id_to_test_setters)

                setattr(tokenizer, "additional_special_tokens_ids", [])
                self.assertListEqual(getattr(tokenizer, "additional_special_tokens"), [])
                self.assertListEqual(getattr(tokenizer, "additional_special_tokens_ids"), [])

                setattr(tokenizer, "additional_special_tokens_ids", [token_id_to_test_setters])
                self.assertListEqual(getattr(tokenizer, "additional_special_tokens"), [token_to_test_setters])
                self.assertListEqual(getattr(tokenizer, "additional_special_tokens_ids"), [token_id_to_test_setters])
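
The expected ids in `test_multibytes_char` above follow ByT5's byte-level scheme: each UTF-8 byte `b` is mapped to id `b + 3`, and `</s>` (id 1) is appended. A minimal sketch under that assumption; the offset of 3 and the helper name `byt5_like_encode` are inferred from the expected ids in the test, not taken from the diff itself:

# Minimal sketch of ByT5's byte-to-id mapping, inferred from the expected ids in
# test_multibytes_char above: id = utf8_byte + 3 (assumed offset reserving ids 0-2
# for special tokens), with </s> appended as id 1.
OFFSET = 3  # assumed: ids 0-2 are reserved for special tokens
EOS_ID = 1

def byt5_like_encode(text: str) -> list[int]:
    # Encode each UTF-8 byte with the fixed offset, then append </s>.
    return [b + OFFSET for b in text.encode("utf-8")] + [EOS_ID]

# Matches the expected ids for "Unicode €." in the test above:
assert byt5_like_encode("Unicode €.") == [88, 113, 108, 102, 114, 103, 104, 35, 229, 133, 175, 49, 1]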
docs/transformers/tests/models/camembert/__init__.py ADDED (empty file)
docs/transformers/tests/models/camembert/test_modeling_tf_camembert.py ADDED
@@ -0,0 +1,55 @@
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import unittest

from transformers import is_tf_available
from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow


if is_tf_available():
    import numpy as np
    import tensorflow as tf

    from transformers import TFCamembertModel


@require_tf
@require_sentencepiece
@require_tokenizers
class TFCamembertModelIntegrationTest(unittest.TestCase):
    @slow
    def test_output_embeds_base_model(self):
        model = TFCamembertModel.from_pretrained("jplu/tf-camembert-base")

        input_ids = tf.convert_to_tensor(
            [[5, 121, 11, 660, 16, 730, 25543, 110, 83, 6]],
            dtype=tf.int32,
        )  # J'aime le camembert !"

        output = model(input_ids)["last_hidden_state"]
        expected_shape = tf.TensorShape((1, 10, 768))
        self.assertEqual(output.shape, expected_shape)
        # compare the actual values for a slice.
        expected_slice = tf.convert_to_tensor(
            [[[-0.0254, 0.0235, 0.1027], [0.0606, -0.1811, -0.0418], [-0.1561, -0.1127, 0.2687]]],
            dtype=tf.float32,
        )
        # camembert = torch.hub.load('pytorch/fairseq', 'camembert.v0')
        # camembert.eval()
        # expected_slice = roberta.model.forward(input_ids)[0][:, :3, :3].detach()

        self.assertTrue(np.allclose(output[:, :3, :3].numpy(), expected_slice.numpy(), atol=1e-4))
docs/transformers/tests/models/camembert/test_tokenization_camembert.py ADDED
@@ -0,0 +1,220 @@
# Copyright 2018 HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tempfile
import unittest
from tempfile import TemporaryDirectory

from transformers import AddedToken, CamembertTokenizer, CamembertTokenizerFast
from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_tokenizers, slow
from transformers.utils import is_torch_available

from ...test_tokenization_common import TokenizerTesterMixin


SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
SAMPLE_BPE_VOCAB = get_tests_dir("fixtures/test_sentencepiece_bpe.model")

FRAMEWORK = "pt" if is_torch_available() else "tf"


@require_sentencepiece
@require_tokenizers
class CamembertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
    from_pretrained_id = "almanach/camembert-base"
    tokenizer_class = CamembertTokenizer
    rust_tokenizer_class = CamembertTokenizerFast
    test_rust_tokenizer = True
    test_sentencepiece = True

    @classmethod
    def setUpClass(cls):
        super().setUpClass()

        # We have a SentencePiece fixture for testing
        tokenizer = CamembertTokenizer(SAMPLE_VOCAB)
        tokenizer.save_pretrained(cls.tmpdirname)

    @unittest.skip(
        "Token maps are not equal because someone set the probability of ('<unk>NOTUSED', -100), so it's never encoded for fast"
    )
    def test_special_tokens_map_equal(self):
        return

    def test_convert_token_and_id(self):
        """Test ``_convert_token_to_id`` and ``_convert_id_to_token``."""
        token = "<pad>"
        token_id = 1  # 1 is the offset id, but in the spm vocab it's 3

        self.assertEqual(self.get_tokenizer().convert_tokens_to_ids(token), token_id)
        self.assertEqual(self.get_tokenizer().convert_ids_to_tokens(token_id), token)

    def test_get_vocab(self):
        vocab_keys = list(self.get_tokenizer().get_vocab().keys())

        self.assertEqual(vocab_keys[0], "<s>NOTUSED")
        self.assertEqual(vocab_keys[1], "<pad>")
        self.assertEqual(vocab_keys[-1], "<mask>")
        self.assertEqual(len(vocab_keys), 1_005)

    def test_vocab_size(self):
        self.assertEqual(self.get_tokenizer().vocab_size, 1_000)

    def test_rust_and_python_bpe_tokenizers(self):
        tokenizer = CamembertTokenizer(SAMPLE_BPE_VOCAB)
        with TemporaryDirectory() as tmpdirname:
            tokenizer.save_pretrained(tmpdirname)
            rust_tokenizer = CamembertTokenizerFast.from_pretrained(tmpdirname)

        sequence = "I was born in 92000, and this is falsé."

        ids = tokenizer.encode(sequence)
        rust_ids = rust_tokenizer.encode(sequence)
        self.assertListEqual(ids, rust_ids)

        ids = tokenizer.encode(sequence, add_special_tokens=False)
        rust_ids = rust_tokenizer.encode(sequence, add_special_tokens=False)
        self.assertListEqual(ids, rust_ids)

        # <unk> tokens are not the same for `rust` than for `slow`.
        # Because spm gives back raw token instead of `unk` in EncodeAsPieces
        # tokens = tokenizer.tokenize(sequence)
        tokens = tokenizer.convert_ids_to_tokens(ids)
        rust_tokens = rust_tokenizer.tokenize(sequence)
        self.assertListEqual(tokens, rust_tokens)

    def test_rust_and_python_full_tokenizers(self):
        if not self.test_rust_tokenizer:
            self.skipTest(reason="test_rust_tokenizer is set to False")

        tokenizer = self.get_tokenizer()
        rust_tokenizer = self.get_rust_tokenizer()

        sequence = "I was born in 92000, and this is falsé."

        tokens = tokenizer.tokenize(sequence)
        rust_tokens = rust_tokenizer.tokenize(sequence)
        self.assertListEqual(tokens, rust_tokens)

        ids = tokenizer.encode(sequence, add_special_tokens=False)
        rust_ids = rust_tokenizer.encode(sequence, add_special_tokens=False)
        self.assertListEqual(ids, rust_ids)

        rust_tokenizer = self.get_rust_tokenizer()
        ids = tokenizer.encode(sequence)
        rust_ids = rust_tokenizer.encode(sequence)
        self.assertListEqual(ids, rust_ids)

    @slow
    def test_tokenizer_integration(self):
        expected_encoding = {'input_ids': [[5, 54, 7196, 297, 30, 23, 776, 18, 11, 3215, 3705, 8252, 22, 3164, 1181, 2116, 29, 16, 813, 25, 791, 3314, 20, 3446, 38, 27575, 120, 6, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [5, 468, 17, 11, 9088, 20, 1517, 8, 22804, 18818, 10, 38, 629, 607, 607, 142, 19, 7196, 867, 56, 10326, 24, 2267, 20, 416, 5072, 15612, 233, 734, 7, 2399, 27, 16, 3015, 1649, 7, 24, 20, 4338, 2399, 27, 13, 3400, 14, 13, 6189, 8, 930, 9, 6]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}  # fmt: skip

        # camembert is a french model. So we also use french texts.
        sequences = [
            "Le transformeur est un modèle d'apprentissage profond introduit en 2017, "
            "utilisé principalement dans le domaine du traitement automatique des langues (TAL).",
            "À l'instar des réseaux de neurones récurrents (RNN), les transformeurs sont conçus "
            "pour gérer des données séquentielles, telles que le langage naturel, pour des tâches "
            "telles que la traduction et la synthèse de texte.",
        ]

        self.tokenizer_integration_test_util(
            expected_encoding=expected_encoding,
            model_name="almanach/camembert-base",
            revision="3a0641d9a1aeb7e848a74299e7e4c4bca216b4cf",
            sequences=sequences,
        )

    # Overwritten because we have to use from slow (online pretrained is wrong, the tokenizer.json has a whole)
    def test_added_tokens_serialization(self):
        self.maxDiff = None

        # Utility to test the added vocab
        def _test_added_vocab_and_eos(expected, tokenizer_class, expected_eos, temp_dir):
            tokenizer = tokenizer_class.from_pretrained(temp_dir)
            self.assertTrue(str(expected_eos) not in tokenizer.additional_special_tokens)
            self.assertIn(new_eos, tokenizer.added_tokens_decoder.values())
            self.assertEqual(tokenizer.added_tokens_decoder[tokenizer.eos_token_id], new_eos)
            self.assertTrue(all(item in tokenizer.added_tokens_decoder.items() for item in expected.items()))
            return tokenizer

        new_eos = AddedToken("[NEW_EOS]", rstrip=False, lstrip=True, normalized=False, special=True)
        for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
            with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
                # Load a slow tokenizer from the hub, init with the new token for fast to also include it
                tokenizer = self.get_tokenizer(pretrained_name, eos_token=new_eos)
                EXPECTED_ADDED_TOKENS_DECODER = tokenizer.added_tokens_decoder
                with self.subTest("Hub -> Slow: Test loading a slow tokenizer from the hub)"):
                    self.assertEqual(tokenizer._special_tokens_map["eos_token"], new_eos)
                    self.assertIn(new_eos, list(tokenizer.added_tokens_decoder.values()))

                with tempfile.TemporaryDirectory() as tmp_dir_2:
                    tokenizer.save_pretrained(tmp_dir_2)
                    with self.subTest(
                        "Hub -> Slow -> Slow: Test saving this slow tokenizer and reloading it in the fast class"
                    ):
                        _test_added_vocab_and_eos(
                            EXPECTED_ADDED_TOKENS_DECODER, self.tokenizer_class, new_eos, tmp_dir_2
                        )

                    if self.rust_tokenizer_class is not None:
                        with self.subTest(
                            "Hub -> Slow -> Fast: Test saving this slow tokenizer and reloading it in the fast class"
                        ):
                            tokenizer_fast = _test_added_vocab_and_eos(
                                EXPECTED_ADDED_TOKENS_DECODER, self.rust_tokenizer_class, new_eos, tmp_dir_2
                            )
                            with tempfile.TemporaryDirectory() as tmp_dir_3:
                                tokenizer_fast.save_pretrained(tmp_dir_3)
                                with self.subTest(
                                    "Hub -> Slow -> Fast -> Fast: Test saving this fast tokenizer and reloading it in the fast class"
                                ):
                                    _test_added_vocab_and_eos(
                                        EXPECTED_ADDED_TOKENS_DECODER, self.rust_tokenizer_class, new_eos, tmp_dir_3
                                    )

                                with self.subTest(
                                    "Hub -> Slow -> Fast -> Slow: Test saving this slow tokenizer and reloading it in the slow class"
                                ):
                                    _test_added_vocab_and_eos(
                                        EXPECTED_ADDED_TOKENS_DECODER, self.rust_tokenizer_class, new_eos, tmp_dir_3
                                    )

                with self.subTest("Hub -> Fast: Test loading a fast tokenizer from the hub)"):
                    if self.rust_tokenizer_class is not None:
                        tokenizer_fast = self.get_rust_tokenizer(pretrained_name, eos_token=new_eos, from_slow=True)
                        self.assertEqual(tokenizer_fast._special_tokens_map["eos_token"], new_eos)
                        self.assertIn(new_eos, list(tokenizer_fast.added_tokens_decoder.values()))
                        # We can't test the following because for BC we kept the default rstrip lstrip in slow not fast. Will comment once normalization is alright
                        with self.subTest("Hub -> Fast == Hub -> Slow: make sure slow and fast tokenizer match"):
                            with self.subTest("Hub -> Fast == Hub -> Slow: make sure slow and fast tokenizer match"):
                                self.assertTrue(
                                    all(
                                        item in tokenizer.added_tokens_decoder.items()
                                        for item in EXPECTED_ADDED_TOKENS_DECODER.items()
                                    )
                                )

                        EXPECTED_ADDED_TOKENS_DECODER = tokenizer_fast.added_tokens_decoder
                        with tempfile.TemporaryDirectory() as tmp_dir_4:
                            tokenizer_fast.save_pretrained(tmp_dir_4)
                            with self.subTest("Hub -> Fast -> Fast: saving Fast1 locally and loading"):
                                _test_added_vocab_and_eos(
                                    EXPECTED_ADDED_TOKENS_DECODER, self.rust_tokenizer_class, new_eos, tmp_dir_4
                                )

                            with self.subTest("Hub -> Fast -> Slow: saving Fast1 locally and loading"):
                                _test_added_vocab_and_eos(
                                    EXPECTED_ADDED_TOKENS_DECODER, self.tokenizer_class, new_eos, tmp_dir_4
                                )
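
These modules follow the standard transformers test layout, so an individual case can be run in isolation. A minimal sketch, assuming a checkout where the paths shown in this diff exist; the chosen module and test name are just examples:

# Minimal sketch for running a single test case from this diff with pytest.
import pytest

pytest.main([
    "docs/transformers/tests/models/camembert/test_tokenization_camembert.py",
    "-k", "test_convert_token_and_id",
    "-q",
])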
docs/transformers/tests/models/canine/__init__.py ADDED (empty file)
docs/transformers/tests/models/canine/test_modeling_canine.py
ADDED
|
@@ -0,0 +1,571 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""Testing suite for the PyTorch CANINE model."""
|
| 15 |
+
|
| 16 |
+
import unittest
|
| 17 |
+
|
| 18 |
+
from transformers import CanineConfig, is_torch_available
|
| 19 |
+
from transformers.testing_utils import require_torch, slow, torch_device
|
| 20 |
+
|
| 21 |
+
from ...test_configuration_common import ConfigTester
|
| 22 |
+
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, global_rng, ids_tensor, random_attention_mask
|
| 23 |
+
from ...test_pipeline_mixin import PipelineTesterMixin
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
if is_torch_available():
|
| 27 |
+
import torch
|
| 28 |
+
|
| 29 |
+
from transformers import (
|
| 30 |
+
CanineForMultipleChoice,
|
| 31 |
+
CanineForQuestionAnswering,
|
| 32 |
+
CanineForSequenceClassification,
|
| 33 |
+
CanineForTokenClassification,
|
| 34 |
+
CanineModel,
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class CanineModelTester:
|
| 39 |
+
def __init__(
|
| 40 |
+
self,
|
| 41 |
+
parent,
|
| 42 |
+
batch_size=13,
|
| 43 |
+
seq_length=7,
|
| 44 |
+
is_training=True,
|
| 45 |
+
use_input_mask=True,
|
| 46 |
+
use_token_type_ids=True,
|
| 47 |
+
use_labels=True,
|
| 48 |
+
# let's use a vocab size that's way bigger than BERT's one
|
| 49 |
+
# NOTE: this is not a model parameter, just an input
|
| 50 |
+
vocab_size=100000,
|
| 51 |
+
hidden_size=32,
|
| 52 |
+
num_hidden_layers=2,
|
| 53 |
+
num_attention_heads=4,
|
| 54 |
+
intermediate_size=37,
|
| 55 |
+
hidden_act="gelu",
|
| 56 |
+
hidden_dropout_prob=0.1,
|
| 57 |
+
attention_probs_dropout_prob=0.1,
|
| 58 |
+
max_position_embeddings=512,
|
| 59 |
+
type_vocab_size=16,
|
| 60 |
+
type_sequence_label_size=2,
|
| 61 |
+
initializer_range=0.02,
|
| 62 |
+
num_labels=3,
|
| 63 |
+
num_choices=4,
|
| 64 |
+
num_hash_buckets=16,
|
| 65 |
+
scope=None,
|
| 66 |
+
):
|
| 67 |
+
self.parent = parent
|
| 68 |
+
self.batch_size = batch_size
|
| 69 |
+
self.seq_length = seq_length
|
| 70 |
+
self.is_training = is_training
|
| 71 |
+
self.use_input_mask = use_input_mask
|
| 72 |
+
self.use_token_type_ids = use_token_type_ids
|
| 73 |
+
self.use_labels = use_labels
|
| 74 |
+
self.vocab_size = vocab_size
|
| 75 |
+
self.hidden_size = hidden_size
|
| 76 |
+
self.num_hidden_layers = num_hidden_layers
|
| 77 |
+
self.num_attention_heads = num_attention_heads
|
| 78 |
+
self.intermediate_size = intermediate_size
|
| 79 |
+
self.hidden_act = hidden_act
|
| 80 |
+
self.hidden_dropout_prob = hidden_dropout_prob
|
| 81 |
+
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
| 82 |
+
self.max_position_embeddings = max_position_embeddings
|
| 83 |
+
self.type_vocab_size = type_vocab_size
|
| 84 |
+
self.type_sequence_label_size = type_sequence_label_size
|
| 85 |
+
self.initializer_range = initializer_range
|
| 86 |
+
self.num_labels = num_labels
|
| 87 |
+
self.num_choices = num_choices
|
| 88 |
+
self.num_hash_buckets = num_hash_buckets
|
| 89 |
+
self.scope = scope
|
| 90 |
+
|
| 91 |
+
def prepare_config_and_inputs(self):
|
| 92 |
+
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
| 93 |
+
|
| 94 |
+
input_mask = None
|
| 95 |
+
if self.use_input_mask:
|
| 96 |
+
input_mask = random_attention_mask([self.batch_size, self.seq_length])
|
| 97 |
+
|
| 98 |
+
token_type_ids = None
|
| 99 |
+
if self.use_token_type_ids:
|
| 100 |
+
token_type_ids = ids_tensor(input_ids.shape, self.type_vocab_size)
|
| 101 |
+
|
| 102 |
+
sequence_labels = None
|
| 103 |
+
token_labels = None
|
| 104 |
+
choice_labels = None
|
| 105 |
+
if self.use_labels:
|
| 106 |
+
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
|
| 107 |
+
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
|
| 108 |
+
choice_labels = ids_tensor([self.batch_size], self.num_choices)
|
| 109 |
+
|
| 110 |
+
config = self.get_config()
|
| 111 |
+
|
| 112 |
+
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
| 113 |
+
|
| 114 |
+
def get_config(self):
|
| 115 |
+
return CanineConfig(
|
| 116 |
+
hidden_size=self.hidden_size,
|
| 117 |
+
num_hidden_layers=self.num_hidden_layers,
|
| 118 |
+
num_attention_heads=self.num_attention_heads,
|
| 119 |
+
intermediate_size=self.intermediate_size,
|
| 120 |
+
hidden_act=self.hidden_act,
|
| 121 |
+
hidden_dropout_prob=self.hidden_dropout_prob,
|
| 122 |
+
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
| 123 |
+
max_position_embeddings=self.max_position_embeddings,
|
| 124 |
+
type_vocab_size=self.type_vocab_size,
|
| 125 |
+
is_decoder=False,
|
| 126 |
+
initializer_range=self.initializer_range,
|
| 127 |
+
num_hash_buckets=self.num_hash_buckets,
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
def create_and_check_model(
|
| 131 |
+
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
| 132 |
+
):
|
| 133 |
+
model = CanineModel(config=config)
|
| 134 |
+
model.to(torch_device)
|
| 135 |
+
model.eval()
|
| 136 |
+
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
|
| 137 |
+
result = model(input_ids, token_type_ids=token_type_ids)
|
| 138 |
+
result = model(input_ids)
|
| 139 |
+
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
| 140 |
+
|
| 141 |
+
def create_and_check_for_question_answering(
|
| 142 |
+
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
| 143 |
+
):
|
| 144 |
+
model = CanineForQuestionAnswering(config=config)
|
| 145 |
+
model.to(torch_device)
|
| 146 |
+
model.eval()
|
| 147 |
+
result = model(
|
| 148 |
+
input_ids,
|
| 149 |
+
attention_mask=input_mask,
|
| 150 |
+
token_type_ids=token_type_ids,
|
| 151 |
+
start_positions=sequence_labels,
|
| 152 |
+
end_positions=sequence_labels,
|
| 153 |
+
)
|
| 154 |
+
self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
|
| 155 |
+
self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
|
| 156 |
+
|
| 157 |
+
def create_and_check_for_sequence_classification(
|
| 158 |
+
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
| 159 |
+
):
|
| 160 |
+
config.num_labels = self.num_labels
|
| 161 |
+
model = CanineForSequenceClassification(config)
|
| 162 |
+
model.to(torch_device)
|
| 163 |
+
model.eval()
|
| 164 |
+
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
|
| 165 |
+
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
|
| 166 |
+
|
| 167 |
+
def create_and_check_for_token_classification(
|
| 168 |
+
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
| 169 |
+
):
|
| 170 |
+
config.num_labels = self.num_labels
|
| 171 |
+
model = CanineForTokenClassification(config=config)
|
| 172 |
+
model.to(torch_device)
|
| 173 |
+
model.eval()
|
| 174 |
+
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
|
| 175 |
+
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
|
| 176 |
+
|
| 177 |
+
def create_and_check_for_multiple_choice(
|
| 178 |
+
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
| 179 |
+
):
|
| 180 |
+
config.num_choices = self.num_choices
|
| 181 |
+
model = CanineForMultipleChoice(config=config)
|
| 182 |
+
model.to(torch_device)
|
| 183 |
+
model.eval()
|
| 184 |
+
multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
| 185 |
+
multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
| 186 |
+
multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
| 187 |
+
result = model(
|
| 188 |
+
multiple_choice_inputs_ids,
|
| 189 |
+
attention_mask=multiple_choice_input_mask,
|
| 190 |
+
token_type_ids=multiple_choice_token_type_ids,
|
| 191 |
+
labels=choice_labels,
|
| 192 |
+
)
|
| 193 |
+
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))
|
| 194 |
+
|
| 195 |
+
def prepare_config_and_inputs_for_common(self):
|
| 196 |
+
config_and_inputs = self.prepare_config_and_inputs()
|
| 197 |
+
(
|
| 198 |
+
config,
|
| 199 |
+
input_ids,
|
| 200 |
+
token_type_ids,
|
| 201 |
+
input_mask,
|
| 202 |
+
sequence_labels,
|
| 203 |
+
token_labels,
|
| 204 |
+
choice_labels,
|
| 205 |
+
) = config_and_inputs
|
| 206 |
+
inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
|
| 207 |
+
return config, inputs_dict
|
| 208 |
+
|
| 209 |
+
|
@require_torch
class CanineModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    all_model_classes = (
        (
            CanineModel,
            CanineForMultipleChoice,
            CanineForQuestionAnswering,
            CanineForSequenceClassification,
            CanineForTokenClassification,
        )
        if is_torch_available()
        else ()
    )
    pipeline_model_mapping = (
        {
            "feature-extraction": CanineModel,
            "question-answering": CanineForQuestionAnswering,
            "text-classification": CanineForSequenceClassification,
            "token-classification": CanineForTokenClassification,
            "zero-shot": CanineForSequenceClassification,
        }
        if is_torch_available()
        else {}
    )

    test_mismatched_shapes = False
    test_resize_embeddings = False
    test_pruning = False

    def setUp(self):
        self.model_tester = CanineModelTester(self)
        # we set has_text_modality to False as the config has no vocab_size attribute
        self.config_tester = ConfigTester(self, config_class=CanineConfig, has_text_modality=False, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_for_multiple_choice(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)

    def test_for_question_answering(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_question_answering(*config_and_inputs)

    def test_for_sequence_classification(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)

    def test_for_token_classification(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_token_classification(*config_and_inputs)

    def test_hidden_states_output(self):
        def check_hidden_states_output(inputs_dict, config, model_class):
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            hidden_states = outputs.hidden_states
            # expected_num_layers equals num_hidden_layers of the deep encoder + 1,
            # + 2 for the first shallow encoder, + 2 for the final shallow encoder
            expected_num_layers = self.model_tester.num_hidden_layers + 1 + 2 + 2
            self.assertEqual(len(hidden_states), expected_num_layers)

            seq_length = self.model_tester.seq_length
            for i in range(expected_num_layers):
                if (i < 2) or ((expected_num_layers - i) < 3):
                    # the expected length of the hidden_states of the first and final shallow encoders
                    # is equal to the seq_length
                    self.assertListEqual(
                        list(hidden_states[i].shape[-2:]),
                        [seq_length, self.model_tester.hidden_size],
                    )
                else:
                    # the expected length of the hidden_states of the deep encoder needs to be updated
                    # for CANINE since the seq length is downsampled
                    self.assertListEqual(
                        list(hidden_states[i].shape[-2:]),
                        [seq_length // config.downsampling_rate, self.model_tester.hidden_size],
                    )

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            inputs_dict["output_hidden_states"] = True
            check_hidden_states_output(inputs_dict, config, model_class)

            # check that output_hidden_states also works using config
            del inputs_dict["output_hidden_states"]
            config.output_hidden_states = True

            check_hidden_states_output(inputs_dict, config, model_class)

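    # (Aside on the arithmetic above: CANINE sandwiches its deep transformer between two
    # shallow character-level encoders, so the model exposes num_hidden_layers + 1 hidden
    # states for the deep stack (embeddings included) plus 2 per shallow encoder, and the
    # deep states are seq_length // downsampling_rate long rather than seq_length.)
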
    def test_attention_outputs(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.return_dict = True

        seq_len = getattr(self.model_tester, "seq_length", None)

        for model_class in self.all_model_classes:
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = False
            config.return_dict = True
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = outputs.attentions
            # we add + 2 due to the 2 shallow encoders
            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers + 2)

            # check that output_attentions also works using config
            del inputs_dict["output_attentions"]
            config.output_attentions = True
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = outputs.attentions
            # we add + 2 due to the 2 shallow encoders
            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers + 2)

            self.assertListEqual(
                list(attentions[0].shape[-3:]),
                [self.model_tester.num_attention_heads, seq_len, seq_len],
            )
            out_len = len(outputs)

            # Check attention is always last and order is fine
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = True
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            if hasattr(self.model_tester, "num_hidden_states_types"):
                added_hidden_states = self.model_tester.num_hidden_states_types
            else:
                added_hidden_states = 1
            self.assertEqual(out_len + added_hidden_states, len(outputs))

            self_attentions = outputs.attentions

            self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers + 2)
            self.assertListEqual(
                list(self_attentions[0].shape[-3:]),
                [self.model_tester.num_attention_heads, seq_len, seq_len],
            )

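    # (Aside: the attentions tuple follows the same pattern as the hidden states, one
    # entry per deep layer plus one for each of the two shallow encoders, hence the
    # num_hidden_layers + 2 count asserted above.)
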
    def test_model_outputs_equivalence(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        def set_nan_tensor_to_zero(t):
            t[t != t] = 0
            return t

        def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
            with torch.no_grad():
                tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
                dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()

                def recursive_check(tuple_object, dict_object):
                    if isinstance(tuple_object, (list, tuple)):
                        for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
                            recursive_check(tuple_iterable_value, dict_iterable_value)
                    elif tuple_object is None:
                        return
                    else:
                        self.assertTrue(
                            torch.allclose(
                                set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
                            ),
                            msg=(
                                "Tuple and dict output are not equal. Difference:"
                                f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
                                f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has"
                                f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}."
                            ),
                        )

                recursive_check(tuple_output, dict_output)

        for model_class in self.all_model_classes:
            print(model_class)
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
            dict_inputs = self._prepare_for_class(inputs_dict, model_class)
            check_equivalence(model, tuple_inputs, dict_inputs)

            tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
            dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
            check_equivalence(model, tuple_inputs, dict_inputs)

            tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
            dict_inputs = self._prepare_for_class(inputs_dict, model_class)
            check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})

            tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
            dict_inputs = self._prepare_for_class(inputs_dict, model_class)
            check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True})

            tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
            dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
            check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})

            tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
            dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
            check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True})

            tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
            dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
            check_equivalence(
                model, tuple_inputs, dict_inputs, {"output_hidden_states": True, "output_attentions": True}
            )

    def test_headmasking(self):
        if not self.test_head_masking:
            self.skipTest(reason="test_head_masking is set to False")

        global_rng.seed(42)
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        global_rng.seed()

        inputs_dict["output_attentions"] = True
        config.output_hidden_states = True
        configs_no_init = _config_zero_init(config)  # To be sure we have no Nan
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            model.to(torch_device)
            model.eval()

            # Prepare head_mask
            # Set require_grad after having prepared the tensor to avoid error (leaf variable has been moved into the graph interior)
            head_mask = torch.ones(
                self.model_tester.num_hidden_layers,
                self.model_tester.num_attention_heads,
                device=torch_device,
            )
            head_mask[0, 0] = 0
            head_mask[-1, :-1] = 0
            head_mask.requires_grad_(requires_grad=True)
            inputs = self._prepare_for_class(inputs_dict, model_class).copy()
            inputs["head_mask"] = head_mask

            outputs = model(**inputs, return_dict=True)

            # Test that we can get a gradient back for importance score computation
            output = sum(t.sum() for t in outputs[0])
            output = output.sum()
            output.backward()
            multihead_outputs = head_mask.grad

            self.assertIsNotNone(multihead_outputs)
            self.assertEqual(len(multihead_outputs), self.model_tester.num_hidden_layers)

            def check_attentions_validity(attentions):
                # Remove Nan
                for t in attentions:
                    self.assertLess(
                        torch.sum(torch.isnan(t)), t.numel() / 4
                    )  # Check we don't have more than 25% nans (arbitrary)
                attentions = [
                    t.masked_fill(torch.isnan(t), 0.0) for t in attentions
                ]  # remove them (the test is less complete)

                self.assertAlmostEqual(attentions[1][..., 0, :, :].flatten().sum().item(), 0.0)
                self.assertNotEqual(attentions[1][..., -1, :, :].flatten().sum().item(), 0.0)
                self.assertAlmostEqual(attentions[-2][..., -2, :, :].flatten().sum().item(), 0.0)
                self.assertNotEqual(attentions[-2][..., -1, :, :].flatten().sum().item(), 0.0)

            check_attentions_validity(outputs.attentions)

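    # (Aside: head_mask has shape (num_hidden_layers, num_attention_heads); a 0 entry
    # disables that head, and because the mask requires grad, head_mask.grad can later
    # be read back as a per-head importance score, which is what this test exercises.)
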
    @unittest.skip(reason="CANINE does not have a get_input_embeddings() method.")
    def test_inputs_embeds(self):
        # CANINE does not use inputs_embeds
        pass

    @unittest.skip(reason="CANINE does not use inputs_embeds")
    def test_inputs_embeds_matches_input_ids(self):
        pass

    @unittest.skip(reason="CANINE does not have a get_input_embeddings() method.")
    def test_model_get_set_embeddings(self):
        pass

    @unittest.skip(
        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing(self):
        pass

    @unittest.skip(
        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant(self):
        pass

    @unittest.skip(
        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass

    @slow
    def test_model_from_pretrained(self):
        model_name = "google/canine-s"
        model = CanineModel.from_pretrained(model_name)
        self.assertIsNotNone(model)


@require_torch
class CanineModelIntegrationTest(unittest.TestCase):
    @slow
    def test_inference_no_head(self):
        model = CanineModel.from_pretrained("google/canine-s")
        # this one corresponds to the first example of the TydiQA dev set (in Swahili)
        # fmt: off
input_ids = [57344, 57349, 85, 107, 117, 98, 119, 97, 32, 119, 97, 32, 82, 105, 106, 105, 108, 105, 32, 75, 97, 110, 116, 111, 114, 105, 32, 110, 105, 32, 107, 105, 97, 115, 105, 32, 103, 97, 110, 105, 63, 57345, 57350, 32, 82, 105, 106, 105, 108, 105, 32, 75, 97, 110, 116, 111, 114, 105, 32, 44, 32, 82, 105, 106, 105, 108, 105, 32, 75, 97, 110, 116, 97, 114, 117, 115, 105, 32, 97, 117, 32, 105, 110, 103, 46, 32, 65, 108, 112, 104, 97, 32, 67, 101, 110, 116, 97, 117, 114, 105, 32, 40, 112, 105, 97, 58, 32, 84, 111, 108, 105, 109, 97, 110, 32, 97, 117, 32, 82, 105, 103, 105, 108, 32, 75, 101, 110, 116, 97, 117, 114, 117, 115, 41, 32, 110, 105, 32, 110, 121, 111, 116, 97, 32, 105, 110, 97, 121, 111, 110, 103, 39, 97, 97, 32, 115, 97, 110, 97, 32, 107, 97, 116, 105, 107, 97, 32, 97, 110, 103, 97, 32, 121, 97, 32, 107, 117, 115, 105, 110, 105, 32, 107, 119, 101, 110, 121, 101, 32, 107, 117, 110, 100, 105, 110, 121, 111, 116, 97, 32, 121, 97, 32, 75, 97, 110, 116, 97, 114, 117, 115, 105, 32, 40, 112, 105, 97, 58, 32, 105, 110, 103, 46, 32, 67, 101, 110, 116, 97, 117, 114, 117, 115, 41, 46, 32, 78, 105, 32, 110, 121, 111, 116, 97, 32, 121, 97, 32, 107, 117, 110, 103, 97, 97, 32, 115, 97, 110, 97, 32, 121, 97, 32, 110, 110, 101, 32, 97, 110, 103, 97, 110, 105, 32, 108, 97, 107, 105, 110, 105, 32, 104, 97, 105, 111, 110, 101, 107, 97, 110, 105, 32, 107, 119, 101, 110, 121, 101, 32, 110, 117, 115, 117, 100, 117, 110, 105, 97, 32, 121, 97, 32, 107, 97, 115, 107, 97, 122, 105, 110, 105, 46, 32, 57351, 32, 65, 108, 112, 104, 97, 32, 67, 101, 110, 116, 97, 117, 114, 105, 32, 110, 105, 32, 110, 121, 111, 116, 97, 32, 121, 97, 32, 112, 101, 107, 101, 101, 32, 107, 119, 97, 32, 115, 97, 98, 97, 98, 117, 32, 110, 105, 32, 110, 121, 111, 116, 97, 32, 121, 101, 116, 117, 32, 106, 105, 114, 97, 110, 105, 32, 107, 97, 116, 105, 107, 97, 32, 97, 110, 103, 97, 32, 105, 110, 97, 32, 117, 109, 98, 97, 108, 105, 32, 119, 97, 32, 109, 105, 97, 107, 97, 32, 121, 97, 32, 110, 117, 114, 117, 32, 52, 46, 50, 46, 32, 73, 110, 97, 111, 110, 101, 107, 97, 110, 97, 32, 97, 110, 103, 97, 110, 105, 32, 107, 97, 114, 105, 98, 117, 32, 110, 97, 32, 107, 117, 110, 100, 105, 110, 121, 111, 116, 97, 32, 121, 97, 32, 83, 97, 108, 105, 98, 117, 32, 40, 67, 114, 117, 120, 41, 46, 32, 57352, 32, 82, 105, 106, 105, 108, 105, 32, 75, 97, 110, 116, 97, 114, 117, 115, 105, 32, 40, 65, 108, 112, 104, 97, 32, 67, 101, 110, 116, 97, 117, 114, 105, 41, 32, 105, 110, 97, 111, 110, 101, 107, 97, 110, 97, 32, 107, 97, 109, 97, 32, 110, 121, 111, 116, 97, 32, 109, 111, 106, 97, 32, 108, 97, 107, 105, 110, 105, 32, 107, 119, 97, 32, 100, 97, 114, 117, 98, 105, 110, 105, 32, 107, 117, 98, 119, 97, 32, 105, 110, 97, 111, 110, 101, 107, 97, 110, 97, 32, 107, 117, 119, 97, 32, 109, 102, 117, 109, 111, 32, 119, 97, 32, 110, 121, 111, 116, 97, 32, 116, 97, 116, 117, 32, 122, 105, 110, 97, 122, 111, 107, 97, 97, 32, 107, 97, 114, 105, 98, 117, 32, 110, 97, 32, 107, 117, 115, 104, 105, 107, 97, 109, 97, 110, 97, 32, 107, 97, 116, 105, 32, 121, 97, 111, 46, 32, 78, 121, 111, 116, 97, 32, 109, 97, 112, 97, 99, 104, 97, 32, 122, 97, 32, 65, 108, 112, 104, 97, 32, 67, 101, 110, 116, 97, 117, 114, 105, 32, 65, 32, 110, 97, 32, 65, 108, 112, 104, 97, 32, 67, 101, 110, 116, 97, 117, 114, 105, 32, 66, 32, 122, 105, 107, 111, 32, 109, 105, 97, 107, 97, 32, 121, 97, 32, 110, 117, 114, 117, 32, 52, 46, 51, 54, 32, 107, 117, 116, 111, 107, 97, 32, 107, 119, 101, 116, 117, 32, 110, 97, 32, 110, 121, 111, 116, 97, 32, 121, 97, 32, 116, 97, 116, 117, 32, 65, 108, 112, 
104, 97, 32, 67, 101, 110, 116, 97, 117, 114, 105, 32, 67, 32, 97, 117, 32, 80, 114, 111, 120, 105, 109, 97, 32, 67, 101, 110, 116, 97, 117, 114, 105, 32, 105, 110, 97, 32, 117, 109, 98, 97, 108, 105, 32, 119, 97, 32, 109, 105, 97, 107, 97, 32, 121, 97, 32, 110, 117, 114, 117, 32, 52, 46, 50, 50, 46, 32, 57353, 32, 80, 114, 111, 120, 105, 109, 97, 32, 67, 101, 110, 116, 97, 117, 114, 105, 32, 40, 121, 97, 97, 110, 105, 32, 110, 121, 111, 116, 97, 32, 121, 97, 32, 75, 97, 110, 116, 97, 114, 117, 115, 105, 32, 105, 108, 105, 121, 111, 32, 107, 97, 114, 105, 98, 117, 32, 122, 97, 105, 100, 105, 32, 110, 97, 115, 105, 41, 32, 105, 109, 101, 103, 117, 110, 100, 117, 108, 105, 119, 97, 32, 107, 117, 119, 97, 32, 110, 97, 32, 115, 97, 121, 97, 114, 105, 32, 109, 111, 106, 97, 46, 32, 86, 105, 112, 105, 109, 111, 32, 118, 105, 110, 97, 118, 121, 111, 112, 97, 116, 105, 107, 97, 110, 97, 32, 104, 97, 100, 105, 32, 115, 97, 115, 97, 32, 122, 105, 110, 97, 111, 110, 121, 101, 115, 104, 97, 32, 117, 119, 101, 122, 101, 107, 97, 110, 111, 32, 109, 107, 117, 98, 119, 97, 32, 121, 97, 32, 107, 119, 97, 109, 98, 97, 32, 115, 97, 121, 97, 114, 105, 32, 104, 105, 105, 32, 110, 105, 32, 121, 97, 32, 109, 119, 97, 109, 98, 97, 32, 40, 107, 97, 109, 97, 32, 100, 117, 110, 105, 97, 32, 121, 101, 116, 117, 44, 32, 77, 105, 114, 105, 104, 105, 32, 97, 117, 32, 90, 117, 104, 117, 114, 97, 41, 32, 110, 97, 32, 105, 110, 97, 119, 101, 122, 97, 32, 107, 117, 119, 97, 32, 110, 97, 32, 97, 110, 103, 97, 104, 101, 119, 97, 44, 32, 116, 101, 110, 97, 32, 107, 97, 116, 105, 107, 97, 32, 117, 112, 101, 111, 32, 119, 97, 32, 106, 111, 116, 111, 32, 117, 110, 97, 111, 114, 117, 104, 117, 115, 117, 32, 107, 117, 119, 101, 112, 111, 32, 107, 119, 97, 32, 117, 104, 97, 105, 46, 32, 91, 49, 93, 57345, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
        attention_mask = [1 if x != 0 else 0 for x in input_ids]
token_type_ids = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
        # fmt: on
        input_ids = torch.tensor([input_ids])
        attention_mask = torch.tensor([attention_mask])
        token_type_ids = torch.tensor([token_type_ids])
        outputs = model(input_ids, attention_mask, token_type_ids)

        # verify sequence output
        expected_shape = torch.Size((1, 2048, 768))
        self.assertEqual(outputs.last_hidden_state.shape, expected_shape)

        expected_slice = torch.tensor(
            [
                [-0.161433131, 0.395568609, 0.0407391489],
                [-0.108025983, 0.362060368, -0.544592619],
                [-0.141537309, 0.180541009, 0.076907],
            ]
        )

        torch.testing.assert_close(outputs.last_hidden_state[0, :3, :3], expected_slice, rtol=1e-2, atol=1e-2)

        # verify pooled output
        expected_shape = torch.Size((1, 768))
        self.assertEqual(outputs.pooler_output.shape, expected_shape)

        expected_slice = torch.tensor([-0.884311497, -0.529064834, 0.723164916])

        torch.testing.assert_close(outputs.pooler_output[0, :3], expected_slice, rtol=1e-2, atol=1e-2)
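An aside on the integration test above: CANINE operates directly on Unicode code points, with [CLS] at 57344 (0xE000), [SEP] at 57345 (0xE001), every other id simply ord(char), and 0 as padding out to the 2048-character window. Below is a minimal sketch of the same forward pass, assuming the google/canine-s checkpoint is reachable and torch is installed; the question string is the one recoverable from the ids, while the test's full input also carries TydiQA passage markers such as 57349, so the exact ids differ.

import torch
from transformers import CanineModel, CanineTokenizer

tokenizer = CanineTokenizer.from_pretrained("google/canine-s")
model = CanineModel.from_pretrained("google/canine-s").eval()

# CANINE needs no vocab file: the tokenizer maps characters straight to code points.
inputs = tokenizer(
    "Ukubwa wa Rijili Kantori ni kiasi gani?",
    padding="max_length",
    max_length=2048,  # the sequence length asserted in test_inference_no_head
    return_tensors="pt",
)
with torch.no_grad():
    outputs = model(**inputs)

print(outputs.last_hidden_state.shape)  # torch.Size([1, 2048, 768]) for canine-s
print(outputs.pooler_output.shape)      # torch.Size([1, 768])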
docs/transformers/tests/models/canine/test_tokenization_canine.py
ADDED
@@ -0,0 +1,339 @@
# Copyright 2021 Google AI and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import shutil
import tempfile
import unittest
from functools import lru_cache

from transformers import BatchEncoding, CanineTokenizer
from transformers.testing_utils import require_tokenizers, require_torch
from transformers.tokenization_utils import AddedToken
from transformers.utils import cached_property

from ...test_tokenization_common import TokenizerTesterMixin, use_cache_if_possible


class CanineTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
    from_pretrained_id = "nielsr/canine-s"
    tokenizer_class = CanineTokenizer
    test_rust_tokenizer = False

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        tokenizer = CanineTokenizer()
        tokenizer.save_pretrained(cls.tmpdirname)

    @cached_property
    def canine_tokenizer(self):
        return CanineTokenizer.from_pretrained("google/canine-s")

    @classmethod
    @use_cache_if_possible
    @lru_cache(maxsize=64)
    def get_tokenizer(cls, pretrained_name=None, **kwargs) -> CanineTokenizer:
        pretrained_name = pretrained_name or cls.tmpdirname
        tokenizer = cls.tokenizer_class.from_pretrained(pretrained_name, **kwargs)
        tokenizer._unicode_vocab_size = 1024
        return tokenizer

    @require_torch
    def test_prepare_batch_integration(self):
        tokenizer = self.canine_tokenizer
        src_text = ["Life is like a box of chocolates.", "You never know what you're gonna get."]
        expected_src_tokens = [57344, 76, 105, 102, 101, 32, 105, 115, 32, 108, 105, 107, 101, 32, 97, 32, 98, 111, 120, 32, 111, 102, 32, 99, 104, 111, 99, 111, 108, 97, 116, 101, 115, 46, 57345, 0, 0, 0, 0]  # fmt: skip
        batch = tokenizer(src_text, padding=True, return_tensors="pt")
        self.assertIsInstance(batch, BatchEncoding)

        result = list(batch.input_ids.numpy()[0])

        self.assertListEqual(expected_src_tokens, result)

        self.assertEqual((2, 39), batch.input_ids.shape)
        self.assertEqual((2, 39), batch.attention_mask.shape)

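    # (Aside on the shapes above: 39 = 1 for [CLS] (57344) + the 37 characters of the
    # longer string + 1 for [SEP] (57345). The shorter string has 33 characters, so it
    # is right-padded with four 0s, which is why expected_src_tokens ends in 0, 0, 0, 0.)
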
    @require_torch
    def test_encoding_keys(self):
        tokenizer = self.canine_tokenizer
        src_text = ["Once there was a man.", "He wrote a test in HuggingFace Transformers."]
        batch = tokenizer(src_text, padding=True, return_tensors="pt")
        # check that input_ids, attention_mask and token_type_ids are returned
        self.assertIn("input_ids", batch)
        self.assertIn("attention_mask", batch)
        self.assertIn("token_type_ids", batch)

    @require_torch
    def test_max_length_integration(self):
        tokenizer = self.canine_tokenizer
        tgt_text = [
            "What's the weather?",
            "It's about 25 degrees.",
        ]
        targets = tokenizer(
            text_target=tgt_text, max_length=32, padding="max_length", truncation=True, return_tensors="pt"
        )
        self.assertEqual(32, targets["input_ids"].shape[1])

    # cannot use default save_and_load_tokenizer test method because tokenizer has no vocab
    def test_save_and_load_tokenizer(self):
        # safety check on max_len default value so we are sure the test works
        tokenizers = self.get_tokenizers()
        for tokenizer in tokenizers:
            with self.subTest(f"{tokenizer.__class__.__name__}"):
                self.assertNotEqual(tokenizer.model_max_length, 42)

        # Now let's start the test
        tokenizers = self.get_tokenizers()
        for tokenizer in tokenizers:
            with self.subTest(f"{tokenizer.__class__.__name__}"):
                # Isolate this from the other tests because we save additional tokens/etc
                tmpdirname = tempfile.mkdtemp()

                sample_text = " He is very happy, UNwant\u00e9d,running"
                before_tokens = tokenizer.encode(sample_text, add_special_tokens=False)
                tokenizer.save_pretrained(tmpdirname)

                after_tokenizer = tokenizer.__class__.from_pretrained(tmpdirname)
                after_tokens = after_tokenizer.encode(sample_text, add_special_tokens=False)
                self.assertListEqual(before_tokens, after_tokens)

                shutil.rmtree(tmpdirname)

        tokenizers = self.get_tokenizers(model_max_length=42)
        for tokenizer in tokenizers:
            with self.subTest(f"{tokenizer.__class__.__name__}"):
                # Isolate this from the other tests because we save additional tokens/etc
                tmpdirname = tempfile.mkdtemp()

                sample_text = " He is very happy, UNwant\u00e9d,running"

                additional_special_tokens = tokenizer.additional_special_tokens

                # We can add a new special token for Canine as follows:
                new_additional_special_token = chr(0xE007)
                additional_special_tokens.append(new_additional_special_token)
                tokenizer.add_special_tokens(
                    {"additional_special_tokens": additional_special_tokens}, replace_additional_special_tokens=False
                )
                before_tokens = tokenizer.encode(sample_text, add_special_tokens=False)
                tokenizer.save_pretrained(tmpdirname)

                after_tokenizer = tokenizer.__class__.from_pretrained(tmpdirname)
                after_tokens = after_tokenizer.encode(sample_text, add_special_tokens=False)
                self.assertListEqual(before_tokens, after_tokens)
                self.assertIn(new_additional_special_token, after_tokenizer.additional_special_tokens)
                self.assertEqual(after_tokenizer.model_max_length, 42)

                tokenizer = tokenizer.__class__.from_pretrained(tmpdirname, model_max_length=43)
                self.assertEqual(tokenizer.model_max_length, 43)

                shutil.rmtree(tmpdirname)

    def test_add_special_tokens(self):
        tokenizers = self.get_tokenizers(do_lower_case=False)
        for tokenizer in tokenizers:
            with self.subTest(f"{tokenizer.__class__.__name__}"):
                input_text, ids = self.get_clean_sequence(tokenizer)

                # a special token for Canine can be defined as follows:
                SPECIAL_TOKEN = 0xE005
                special_token = chr(SPECIAL_TOKEN)

                tokenizer.add_special_tokens({"cls_token": special_token})
                encoded_special_token = tokenizer.encode(special_token, add_special_tokens=False)
                self.assertEqual(len(encoded_special_token), 1)

                text = tokenizer.decode(ids + encoded_special_token, clean_up_tokenization_spaces=False)
                encoded = tokenizer.encode(text, add_special_tokens=False)

                input_encoded = tokenizer.encode(input_text, add_special_tokens=False)
                special_token_id = tokenizer.encode(special_token, add_special_tokens=False)
                self.assertEqual(encoded, input_encoded + special_token_id)

                decoded = tokenizer.decode(encoded, skip_special_tokens=True)
                self.assertTrue(special_token not in decoded)

    def test_tokenize_special_tokens(self):
        tokenizers = self.get_tokenizers(do_lower_case=True)
        for tokenizer in tokenizers:
            with self.subTest(f"{tokenizer.__class__.__name__}"):
                SPECIAL_TOKEN_1 = chr(0xE005)
                SPECIAL_TOKEN_2 = chr(0xE006)
                tokenizer.add_tokens([SPECIAL_TOKEN_1], special_tokens=True)
                tokenizer.add_special_tokens({"additional_special_tokens": [SPECIAL_TOKEN_2]})

                token_1 = tokenizer.tokenize(SPECIAL_TOKEN_1)
                token_2 = tokenizer.tokenize(SPECIAL_TOKEN_2)

                self.assertEqual(len(token_1), 1)
                self.assertEqual(len(token_2), 1)
                self.assertEqual(token_1[0], SPECIAL_TOKEN_1)
                self.assertEqual(token_2[0], SPECIAL_TOKEN_2)

    @require_tokenizers
    def test_added_token_serializable(self):
        tokenizers = self.get_tokenizers(do_lower_case=False)
        for tokenizer in tokenizers:
            with self.subTest(f"{tokenizer.__class__.__name__}"):
                # a special token for Canine can be defined as follows:
                NEW_TOKEN = 0xE006
                new_token = chr(NEW_TOKEN)

                new_token = AddedToken(new_token, lstrip=True)
                tokenizer.add_special_tokens({"additional_special_tokens": [new_token]})

                with tempfile.TemporaryDirectory() as tmp_dir_name:
                    tokenizer.save_pretrained(tmp_dir_name)
                    tokenizer.from_pretrained(tmp_dir_name)

    def test_special_tokens_initialization_with_non_empty_additional_special_tokens(self):
        tokenizer_list = []
        if self.test_slow_tokenizer:
            tokenizer_list.append((self.tokenizer_class, self.get_tokenizer()))

        if self.test_rust_tokenizer:
            tokenizer_list.append((self.rust_tokenizer_class, self.get_rust_tokenizer()))

        for tokenizer_class, tokenizer_utils in tokenizer_list:
            with tempfile.TemporaryDirectory() as tmp_dir:
                tokenizer_utils.save_pretrained(tmp_dir)

                with open(os.path.join(tmp_dir, "special_tokens_map.json"), encoding="utf-8") as json_file:
                    special_tokens_map = json.load(json_file)

                with open(os.path.join(tmp_dir, "tokenizer_config.json"), encoding="utf-8") as json_file:
                    tokenizer_config = json.load(json_file)

                # a special token for Canine can be defined as follows:
                NEW_TOKEN = 0xE006
                new_token_1 = chr(NEW_TOKEN)

                special_tokens_map["additional_special_tokens"] = [new_token_1]
                tokenizer_config["additional_special_tokens"] = [new_token_1]

                with open(os.path.join(tmp_dir, "special_tokens_map.json"), "w", encoding="utf-8") as outfile:
                    json.dump(special_tokens_map, outfile)
                with open(os.path.join(tmp_dir, "tokenizer_config.json"), "w", encoding="utf-8") as outfile:
                    json.dump(tokenizer_config, outfile)

                # the following checks allow us to verify that our test works as expected, i.e. that the tokenizer
                # takes into account the new value of additional_special_tokens given in the "tokenizer_config.json"
                # and "special_tokens_map.json" files
                tokenizer_without_change_in_init = tokenizer_class.from_pretrained(tmp_dir, extra_ids=0)
                self.assertIn(new_token_1, tokenizer_without_change_in_init.additional_special_tokens)
                # self.assertIn("an_additional_special_token", tokenizer_without_change_in_init.get_vocab())  # CanineTokenizer has no vocab
                self.assertEqual(
                    [new_token_1],
                    tokenizer_without_change_in_init.convert_ids_to_tokens(
                        tokenizer_without_change_in_init.convert_tokens_to_ids([new_token_1])
                    ),
                )

                NEW_TOKEN = 0xE007
                new_token_2 = chr(NEW_TOKEN)
                # Now we test that we can change the value of additional_special_tokens in from_pretrained
                new_added_tokens = [AddedToken(new_token_2, lstrip=True)]
                tokenizer = tokenizer_class.from_pretrained(
                    tmp_dir, additional_special_tokens=new_added_tokens, extra_ids=0
                )

                self.assertIn(new_token_2, tokenizer.additional_special_tokens)
                # self.assertIn(new_token_2, tokenizer.get_vocab())  # CanineTokenizer has no vocab
                self.assertEqual(
                    [new_token_2], tokenizer.convert_ids_to_tokens(tokenizer.convert_tokens_to_ids([new_token_2]))
                )

    @require_tokenizers
    def test_encode_decode_with_spaces(self):
        tokenizers = self.get_tokenizers(do_lower_case=False)
        for tokenizer in tokenizers:
            with self.subTest(f"{tokenizer.__class__.__name__}"):
                input = "hello world"
                if self.space_between_special_tokens:
                    output = "[CLS] hello world [SEP]"
                else:
                    output = input
                encoded = tokenizer.encode(input, add_special_tokens=False)
                decoded = tokenizer.decode(encoded, spaces_between_special_tokens=self.space_between_special_tokens)
                self.assertIn(decoded, [output, output.lower()])

    # cannot use default `test_tokenizers_common_ids_setters` method because tokenizer has no vocab
    def test_tokenizers_common_ids_setters(self):
        tokenizers = self.get_tokenizers()
        for tokenizer in tokenizers:
            with self.subTest(f"{tokenizer.__class__.__name__}"):
                attributes_list = [
                    "bos_token",
                    "eos_token",
                    "unk_token",
                    "sep_token",
                    "pad_token",
                    "cls_token",
                    "mask_token",
                ]

                token_to_test_setters = "a"
                token_id_to_test_setters = ord(token_to_test_setters)

                for attr in attributes_list:
                    setattr(tokenizer, attr + "_id", None)
                    self.assertEqual(getattr(tokenizer, attr), None)
                    self.assertEqual(getattr(tokenizer, attr + "_id"), None)

                    setattr(tokenizer, attr + "_id", token_id_to_test_setters)
                    self.assertEqual(getattr(tokenizer, attr), token_to_test_setters)
                    self.assertEqual(getattr(tokenizer, attr + "_id"), token_id_to_test_setters)

                setattr(tokenizer, "additional_special_tokens_ids", [])
                self.assertListEqual(getattr(tokenizer, "additional_special_tokens"), [])
                self.assertListEqual(getattr(tokenizer, "additional_special_tokens_ids"), [])

                additional_special_token_id = 0xE006
                additional_special_token = chr(additional_special_token_id)
                setattr(tokenizer, "additional_special_tokens_ids", [additional_special_token_id])
                self.assertListEqual(getattr(tokenizer, "additional_special_tokens"), [additional_special_token])
                self.assertListEqual(getattr(tokenizer, "additional_special_tokens_ids"), [additional_special_token_id])

    @unittest.skip(reason="tokenizer has a fixed vocab_size (namely all possible Unicode code points)")
    def test_add_tokens_tokenizer(self):
        pass

    # CanineTokenizer does not support do_lower_case = True, as each character has its own Unicode code point
    # ("b" and "B" for example have different Unicode code points)
    @unittest.skip(reason="CanineTokenizer does not support do_lower_case = True")
    def test_added_tokens_do_lower_case(self):
        pass

    @unittest.skip(reason="CanineModel does not support the get_input_embeddings nor the get_vocab method")
    def test_np_encode_plus_sent_to_model(self):
        pass

    @unittest.skip(reason="CanineModel does not support the get_input_embeddings nor the get_vocab method")
    def test_torch_encode_plus_sent_to_model(self):
        pass

    @unittest.skip(reason="CanineTokenizer does not have vocabulary")
    def test_get_vocab(self):
        pass

    @unittest.skip(reason="inputs cannot be pretokenized since ids depend on the whole input string")
    def test_pretokenized_inputs(self):
        pass

    @unittest.skip(reason="CanineTokenizer does not have vocabulary")
    def test_conversion_reversible(self):
        pass
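An aside on the tokenizer exercised above: it needs no vocabulary file, since ids are Unicode code points plus a handful of private-use specials. A minimal round-trip sketch, assuming the google/canine-s checkpoint:

from transformers import CanineTokenizer

tok = CanineTokenizer.from_pretrained("google/canine-s")
text = "Life is like a box of chocolates."
ids = tok(text)["input_ids"]

# [CLS] (57344 = 0xE000) + one code point per character + [SEP] (57345 = 0xE001)
assert ids[0] == 0xE000 and ids[-1] == 0xE001
assert ids[1:-1] == [ord(c) for c in text]
assert tok.decode(ids, skip_special_tokens=True) == text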
docs/transformers/tests/models/chameleon/__init__.py
ADDED
File without changes
docs/transformers/tests/models/chameleon/test_image_processing_chameleon.py
ADDED
@@ -0,0 +1,204 @@
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torch_available, is_vision_available

from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs


if is_torch_available():
    import torch

if is_vision_available():
    from PIL import Image

    from transformers import ChameleonImageProcessor


class ChameleonImageProcessingTester:
    def __init__(
        self,
        parent,
        batch_size=7,
        num_channels=3,
        image_size=18,
        min_resolution=30,
        max_resolution=200,
        do_resize=True,
        size=None,
        do_center_crop=True,
        crop_size=None,
        do_normalize=True,
        image_mean=[1.0, 1.0, 1.0],
        image_std=[1.0, 1.0, 1.0],
        do_convert_rgb=True,
    ):
        size = size if size is not None else {"shortest_edge": 18}
        crop_size = crop_size if crop_size is not None else {"height": 18, "width": 18}
        self.parent = parent
        self.batch_size = batch_size
        self.num_channels = num_channels
        self.image_size = image_size
        self.min_resolution = min_resolution
        self.max_resolution = max_resolution
        self.do_resize = do_resize
        self.size = size
        self.do_center_crop = do_center_crop
        self.crop_size = crop_size
        self.do_normalize = do_normalize
        self.image_mean = image_mean
        self.image_std = image_std
        self.do_convert_rgb = do_convert_rgb

    def prepare_image_processor_dict(self):
        return {
            "do_resize": self.do_resize,
            "size": self.size,
            "do_center_crop": self.do_center_crop,
            "crop_size": self.crop_size,
            "do_normalize": self.do_normalize,
            "image_mean": self.image_mean,
            "image_std": self.image_std,
            "do_convert_rgb": self.do_convert_rgb,
        }

    # Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTester.expected_output_image_shape
    def expected_output_image_shape(self, images):
        return self.num_channels, self.crop_size["height"], self.crop_size["width"]

    # Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTester.prepare_image_inputs
    def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False):
        return prepare_image_inputs(
            batch_size=self.batch_size,
            num_channels=self.num_channels,
            min_resolution=self.min_resolution,
            max_resolution=self.max_resolution,
            equal_resolution=equal_resolution,
            numpify=numpify,
            torchify=torchify,
        )


@require_torch
@require_vision
class ChameleonImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
    image_processing_class = ChameleonImageProcessor if is_vision_available() else None

    # Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTest.setUp with CLIP->Chameleon
    def setUp(self):
        super().setUp()
        self.image_processor_tester = ChameleonImageProcessingTester(self)

    @property
    # Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTest.image_processor_dict
    def image_processor_dict(self):
        return self.image_processor_tester.prepare_image_processor_dict()

    def test_image_processor_properties(self):
        image_processing = self.image_processing_class(**self.image_processor_dict)
        self.assertTrue(hasattr(image_processing, "do_resize"))
        self.assertTrue(hasattr(image_processing, "size"))
        self.assertTrue(hasattr(image_processing, "do_center_crop"))
        self.assertTrue(hasattr(image_processing, "center_crop"))
        self.assertTrue(hasattr(image_processing, "do_normalize"))
        self.assertTrue(hasattr(image_processing, "image_mean"))
        self.assertTrue(hasattr(image_processing, "image_std"))
        self.assertTrue(hasattr(image_processing, "do_convert_rgb"))

    def test_image_processor_from_dict_with_kwargs(self):
        image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
        self.assertEqual(image_processor.size, {"shortest_edge": 18})
        self.assertEqual(image_processor.crop_size, {"height": 18, "width": 18})

        image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84)
        self.assertEqual(image_processor.size, {"shortest_edge": 42})
        self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84})

| 133 |
+
def test_call_pil(self):
|
| 134 |
+
# Initialize image_processing
|
| 135 |
+
image_processing = self.image_processing_class(**self.image_processor_dict)
|
| 136 |
+
# create random PIL images
|
| 137 |
+
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True)
|
| 138 |
+
for image in image_inputs:
|
| 139 |
+
self.assertIsInstance(image, Image.Image)
|
| 140 |
+
|
| 141 |
+
# Test not batched input
|
| 142 |
+
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
|
| 143 |
+
expected_output_image_shape = (1, 3, 18, 18)
|
| 144 |
+
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
|
| 145 |
+
|
| 146 |
+
# Test batched
|
| 147 |
+
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
|
| 148 |
+
expected_output_image_shape = (7, 3, 18, 18)
|
| 149 |
+
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
|
| 150 |
+
|
| 151 |
+
def test_call_numpy(self):
|
| 152 |
+
# Initialize image_processing
|
| 153 |
+
image_processing = self.image_processing_class(**self.image_processor_dict)
|
| 154 |
+
# create random numpy tensors
|
| 155 |
+
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, numpify=True)
|
| 156 |
+
for image in image_inputs:
|
| 157 |
+
self.assertIsInstance(image, np.ndarray)
|
| 158 |
+
|
| 159 |
+
# Test not batched input
|
| 160 |
+
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
|
| 161 |
+
expected_output_image_shape = (1, 3, 18, 18)
|
| 162 |
+
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
|
| 163 |
+
|
| 164 |
+
# Test batched
|
| 165 |
+
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
|
| 166 |
+
expected_output_image_shape = (7, 3, 18, 18)
|
| 167 |
+
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
|
| 168 |
+
|
| 169 |
+
def test_call_pytorch(self):
|
| 170 |
+
# Initialize image_processing
|
| 171 |
+
image_processing = self.image_processing_class(**self.image_processor_dict)
|
| 172 |
+
# create random PyTorch tensors
|
| 173 |
+
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, torchify=True)
|
| 174 |
+
|
| 175 |
+
for image in image_inputs:
|
| 176 |
+
self.assertIsInstance(image, torch.Tensor)
|
| 177 |
+
|
| 178 |
+
# Test not batched input
|
| 179 |
+
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
|
| 180 |
+
expected_output_image_shape = (1, 3, 18, 18)
|
| 181 |
+
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
|
| 182 |
+
|
| 183 |
+
# Test batched
|
| 184 |
+
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
|
| 185 |
+
expected_output_image_shape = (7, 3, 18, 18)
|
| 186 |
+
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
|
| 187 |
+
|
| 188 |
+
def test_nested_input(self):
|
| 189 |
+
image_processing = self.image_processing_class(**self.image_processor_dict)
|
| 190 |
+
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True)
|
| 191 |
+
|
| 192 |
+
# Test batched as a list of images
|
| 193 |
+
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
|
| 194 |
+
expected_output_image_shape = (7, 3, 18, 18)
|
| 195 |
+
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
|
| 196 |
+
|
| 197 |
+
# Test batched as a nested list of images, where each sublist is one batch
|
| 198 |
+
image_inputs_nested = [image_inputs[:3], image_inputs[3:]]
|
| 199 |
+
encoded_images_nested = image_processing(image_inputs_nested, return_tensors="pt").pixel_values
|
| 200 |
+
expected_output_image_shape = (7, 3, 18, 18)
|
| 201 |
+
self.assertEqual(tuple(encoded_images_nested.shape), expected_output_image_shape)
|
| 202 |
+
|
| 203 |
+
# Image processor should return same pixel values, independently of input format
|
| 204 |
+
self.assertTrue((encoded_images_nested == encoded_images).all())
|
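For orientation, here is a minimal standalone sketch (not part of the diff) of the preprocessing contract the tests above assert: with size={"shortest_edge": 18} and an 18x18 center crop, an RGB input of any resolution comes out as a (1, 3, 18, 18) tensor. The constructor kwargs mirror the tester defaults above; this assumes the vision extras (Pillow, torch) are installed.

import numpy as np
from PIL import Image
from transformers import ChameleonImageProcessor

# Build the processor with the same tiny test configuration used by the tester above.
processor = ChameleonImageProcessor(size={"shortest_edge": 18}, crop_size={"height": 18, "width": 18})
# A random 64x48 RGB image stands in for real data.
image = Image.fromarray(np.random.randint(0, 256, (64, 48, 3), dtype=np.uint8))
pixel_values = processor(image, return_tensors="pt").pixel_values
assert tuple(pixel_values.shape) == (1, 3, 18, 18)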
docs/transformers/tests/models/chameleon/test_modeling_chameleon.py
ADDED
@@ -0,0 +1,481 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch Chameleon model."""

import copy
import unittest

import requests
from parameterized import parameterized

from transformers import ChameleonConfig, is_torch_available, is_vision_available, set_seed
from transformers.testing_utils import (
    require_bitsandbytes,
    require_read_token,
    require_torch,
    slow,
    torch_device,
)

from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin


if is_vision_available():
    from PIL import Image

if is_torch_available():
    import torch

    from transformers import (
        ChameleonForConditionalGeneration,
        ChameleonModel,
        ChameleonProcessor,
    )


class ChameleonModelTester:
    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=35,
        is_training=False,
        use_input_mask=True,
        use_labels=True,
        vocab_size=99,
        image_token_id=4,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=2,
        num_key_value_heads=2,
        intermediate_size=37,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=16,
        type_sequence_label_size=2,
        initializer_range=0.02,
        num_labels=3,
        num_choices=4,
        pad_token_id=0,
        vq_num_embeds=5,
        vq_embed_dim=5,
        vq_channel_multiplier=[1, 4],
        vq_img_token_start_id=10,  # has to be less than vocab size when added with vq_num_embeds
        scope=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_input_mask = use_input_mask
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.image_token_id = image_token_id
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.type_sequence_label_size = type_sequence_label_size
        self.initializer_range = initializer_range
        self.num_labels = num_labels
        self.num_choices = num_choices
        self.pad_token_id = pad_token_id
        self.scope = scope
        self.vq_num_embeds = vq_num_embeds
        self.vq_embed_dim = vq_embed_dim
        self.vq_channel_multiplier = vq_channel_multiplier
        self.vq_img_token_start_id = vq_img_token_start_id

    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
            input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))

        sequence_labels = None
        token_labels = None
        choice_labels = None
        if self.use_labels:
            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
            choice_labels = ids_tensor([self.batch_size], self.num_choices)

        config = self.get_config()

        return config, input_ids, input_mask, sequence_labels, token_labels, choice_labels

    def get_config(self):
        # create dummy vocab map for image2bpe mapping if it needs remapping
        # we assume that vocab size is big enough to account for image tokens somewhere in the beginning
        # same way as in real ckpt, when img tokens are in first half of embeds
        # we will need "vq_num_embeds" amount of tokens

        vocab_map = {i: chr(i) for i in range(self.vocab_size)}
        vocab_map[self.image_token_id] = "<image>"
        start = self.vq_img_token_start_id
        end = self.vq_img_token_start_id + self.vq_num_embeds
        for i in range(start, end):
            image_token_infix = "".join(chr(ord("A") + int(c)) for c in str(i))
            # dummy str for each image token, anything starting with IMGIMG
            vocab_map[i] = f"IMGIMG{image_token_infix}Z"

        return ChameleonConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            num_key_value_heads=self.num_key_value_heads,
            intermediate_size=self.intermediate_size,
            hidden_act=self.hidden_act,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            type_vocab_size=self.type_vocab_size,
            is_decoder=False,
            initializer_range=self.initializer_range,
            pad_token_id=self.pad_token_id,
            vocabulary_map={v: k for k, v in vocab_map.items()},
            vq_config=self.get_vq_config(),
        )

    def get_vq_config(self):
        return {
            "embed_dim": self.vq_embed_dim,
            "num_embeddings": self.vq_num_embeds,
            "latent_channels": self.vq_embed_dim,
            "in_channels": 3,
            "base_channels": 32,  # we have a GroupNorm of 32 groups, so can't do less
            "channel_multiplier": self.vq_channel_multiplier,
        }

    def create_and_check_model(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels):
        model = ChameleonModel(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask)
        result = model(input_ids)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        (
            config,
            input_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
        ) = config_and_inputs
        inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
        return config, inputs_dict


@require_torch
class ChameleonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
    all_model_classes = (ChameleonModel, ChameleonForConditionalGeneration) if is_torch_available() else ()
    pipeline_model_mapping = (
        {
            "feature-extraction": ChameleonModel,
            "text-generation": ChameleonForConditionalGeneration,
        }
        if is_torch_available()
        else {}
    )
    test_headmasking = False
    test_pruning = False
    fx_compatible = False

    def setUp(self):
        self.model_tester = ChameleonModelTester(self)
        self.config_tester = ConfigTester(self, config_class=ChameleonConfig, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    @parameterized.expand([("linear",), ("dynamic",)])
    def test_model_rope_scaling(self, scaling_type):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        short_input = ids_tensor([1, 10], config.vocab_size)
        long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size)

        set_seed(42)  # Fixed seed at init time so the two models get the same random weights
        original_model = ChameleonModel(config)
        original_model.to(torch_device)
        original_model.eval()
        original_short_output = original_model(short_input).last_hidden_state
        original_long_output = original_model(long_input).last_hidden_state

        set_seed(42)  # Fixed seed at init time so the two models get the same random weights
        config.rope_scaling = {"type": scaling_type, "factor": 10.0}
        scaled_model = ChameleonModel(config)
        scaled_model.to(torch_device)
        scaled_model.eval()
        scaled_short_output = scaled_model(short_input).last_hidden_state
        scaled_long_output = scaled_model(long_input).last_hidden_state

        # Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original
        # maximum sequence length, so the outputs for the short input should match.
        if scaling_type == "dynamic":
            torch.testing.assert_close(original_short_output, scaled_short_output, rtol=1e-5, atol=1e-5)
        else:
            self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))

        # The output should be different for long inputs
        self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))

    @unittest.skip("Chameleon forces some token ids to be -inf!")
    def test_batching_equivalence(self):
        pass

    @unittest.skip("Chameleon VQ model cannot be squished further due to hardcoded layer params in the model code")
    def test_model_is_small(self):
        pass


class ChameleonVision2SeqModelTester(ChameleonModelTester):
    def __init__(self, parent, image_size=10, **kwargs):
        super().__init__(parent, **kwargs)
        self.image_size = image_size
        self.image_seq_length = 25

    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
        input_ids[input_ids == self.image_token_id] = self.pad_token_id
        input_ids[:, : self.image_seq_length] = self.image_token_id
        attention_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))
        pixel_values = floats_tensor([self.batch_size, 3, self.image_size, self.image_size])

        config = self.get_config()

        return config, input_ids, attention_mask, pixel_values

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, input_ids, attention_mask, pixel_values = config_and_inputs
        inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask, "pixel_values": pixel_values}
        return config, inputs_dict


@require_torch
class ChameleonVision2SeqModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
    all_model_classes = (ChameleonModel, ChameleonForConditionalGeneration) if is_torch_available() else ()
    pipeline_model_mapping = (
        {
            "image-text-to-text": ChameleonForConditionalGeneration,
        }
        if is_torch_available()
        else {}
    )
    test_headmasking = False
    test_pruning = False
    fx_compatible = False

    def setUp(self):
        self.model_tester = ChameleonVision2SeqModelTester(self)
        self.config_tester = ConfigTester(self, config_class=ChameleonConfig, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    @unittest.skip("Chameleon forces some token ids to be -inf!")
    def test_batching_equivalence(self):
        pass

    @unittest.skip("Chameleon cannot do offload because it uses `self.linear.weight` in forward")
    def test_cpu_offload(self):
        pass

    @unittest.skip("Chameleon cannot do offload because it uses `self.linear.weight` in forward")
    def test_disk_offload_bin(self):
        pass

    @unittest.skip("Chameleon cannot do offload because it uses `self.linear.weight` in forward")
    def test_disk_offload_safetensors(self):
        pass

    @unittest.skip("Chameleon VQ model cannot be squished further due to hardcoded layer params in the model code")
    def test_model_is_small(self):
        pass

    def test_mismatching_num_image_tokens(self):
        """
        Tests that VLMs raise an error with an explicit message saying what is wrong
        when the number of images doesn't match the number of image tokens in the text.
        Also tests multi-image cases where one prompt has multiple image tokens.
        """
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        for model_class in self.all_model_classes:
            model = model_class(config).to(torch_device)
            curr_input_dict = copy.deepcopy(input_dict)  # the below tests modify dict in-place
            _ = model(**curr_input_dict)  # successful forward with no modifications

            # remove one image but leave the image token in text
            curr_input_dict["pixel_values"] = curr_input_dict["pixel_values"][-1:, ...]
            with self.assertRaises(ValueError):
                _ = model(**curr_input_dict)

            # simulate multi-image case by concatenating inputs where each has exactly one image/image-token
            input_ids = curr_input_dict["input_ids"][:1]
            pixel_values = curr_input_dict["pixel_values"][:1]
            input_ids = torch.cat([input_ids, input_ids], dim=0)

            # one image and two image tokens raise an error
            with self.assertRaises(ValueError):
                _ = model(input_ids=input_ids, pixel_values=pixel_values)

            # two images and two image tokens don't raise an error
            pixel_values = torch.cat([pixel_values, pixel_values], dim=0)
            _ = model(input_ids=input_ids, pixel_values=pixel_values)

    # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
    def test_inputs_embeds(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            inputs = self._prepare_for_class(inputs_dict, model_class)

            input_ids = inputs["input_ids"]
            del inputs["input_ids"]
            del inputs["pixel_values"]

            wte = model.get_input_embeddings()
            inputs["inputs_embeds"] = wte(input_ids)

            with torch.no_grad():
                model(**inputs)

    # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
    # while some other models require pixel_values to be present
    def test_inputs_embeds_matches_input_ids(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            inputs = self._prepare_for_class(inputs_dict, model_class)
            input_ids = inputs["input_ids"]
            del inputs["input_ids"]
            del inputs["pixel_values"]

            inputs_embeds = model.get_input_embeddings()(input_ids)

            with torch.no_grad():
                out_ids = model(input_ids=input_ids, **inputs)[0]
                out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
            torch.testing.assert_close(out_embeds, out_ids)


@require_torch
class ChameleonIntegrationTest(unittest.TestCase):
    @slow
    @require_bitsandbytes
    @require_read_token
    def test_model_7b(self):
        model = ChameleonForConditionalGeneration.from_pretrained(
            "facebook/chameleon-7b", load_in_4bit=True, device_map="auto"
        )
        processor = ChameleonProcessor.from_pretrained("facebook/chameleon-7b")

        image = Image.open(
            requests.get("https://nineplanets.org/wp-content/uploads/2020/12/the-big-dipper-1.jpg", stream=True).raw
        )
        prompt = "<image>Describe what do you see here and tell me about the history behind it?"

        inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device, torch.float16)

        # greedy generation outputs
        EXPECTED_TEXT_COMPLETION = ['Describe what do you see here and tell me about the history behind it?The image depicts a star map, with a bright blue dot in the center representing the star Alpha Centauri. The star map is a representation of the night sky, showing the positions of stars in']  # fmt: skip
        generated_ids = model.generate(**inputs, max_new_tokens=40, do_sample=False)
        text = processor.batch_decode(generated_ids, skip_special_tokens=True)
        self.assertEqual(EXPECTED_TEXT_COMPLETION, text)

    @slow
    @require_bitsandbytes
    @require_read_token
    def test_model_7b_batched(self):
        model = ChameleonForConditionalGeneration.from_pretrained(
            "facebook/chameleon-7b", load_in_4bit=True, device_map="auto"
        )
        processor = ChameleonProcessor.from_pretrained("facebook/chameleon-7b")

        image = Image.open(
            requests.get("https://nineplanets.org/wp-content/uploads/2020/12/the-big-dipper-1.jpg", stream=True).raw
        )
        image_2 = Image.open(
            requests.get("https://www.kxan.com/wp-content/uploads/sites/40/2020/10/ORION.jpg", stream=True).raw
        )
        prompts = [
            "<image>Describe what do you see here and tell me about the history behind it?",
            "What constellation is this image showing?<image>",
        ]

        inputs = processor(images=[image, image_2], text=prompts, padding=True, return_tensors="pt").to(
            model.device, torch.float16
        )

        # greedy generation outputs
        EXPECTED_TEXT_COMPLETION = [
            'Describe what do you see here and tell me about the history behind it?The image depicts a star map, with a bright blue dot in the center representing the star Alpha Centauri. The star map is a representation of the night sky, showing the positions of stars in',
            'What constellation is this image showing?The image shows the constellation of Orion.The image shows the constellation of Orion.The image shows the constellation of Orion.The image shows the constellation of Orion.'
        ]  # fmt: skip
        generated_ids = model.generate(**inputs, max_new_tokens=40, do_sample=False)
        text = processor.batch_decode(generated_ids, skip_special_tokens=True)
        self.assertEqual(EXPECTED_TEXT_COMPLETION, text)

    @slow
    @require_bitsandbytes
    @require_read_token
    def test_model_7b_multi_image(self):
        model = ChameleonForConditionalGeneration.from_pretrained(
            "facebook/chameleon-7b", load_in_4bit=True, device_map="auto"
        )
        processor = ChameleonProcessor.from_pretrained("facebook/chameleon-7b")

        image = Image.open(
            requests.get("https://nineplanets.org/wp-content/uploads/2020/12/the-big-dipper-1.jpg", stream=True).raw
        )
        image_2 = Image.open(
            requests.get("https://www.kxan.com/wp-content/uploads/sites/40/2020/10/ORION.jpg", stream=True).raw
        )
        prompt = "What do these two images have in common?<image><image>"

        inputs = processor(images=[image, image_2], text=prompt, return_tensors="pt").to(model.device, torch.float16)

        # greedy generation outputs
        EXPECTED_TEXT_COMPLETION = ['What do these two images have in common?The two images show a connection between the night sky and the internet. The first image shows a starry night sky, with the stars arranged in a pattern that resembles the structure of the internet. The']  # fmt: skip
        generated_ids = model.generate(**inputs, max_new_tokens=40, do_sample=False)
        text = processor.batch_decode(generated_ids, skip_special_tokens=True)
        self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
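As an aside, the dummy vocabulary map built in get_config() above can be illustrated standalone: each decimal digit of a VQ image-token id is mapped to a letter (0 -> "A", 1 -> "B", ...), so id 10 becomes "IMGIMGBAZ". The constants below mirror the tester defaults; this is an illustrative sketch, not part of the diff.

# Rebuild the dummy image-token vocabulary the tester feeds into ChameleonConfig.
vq_img_token_start_id = 10
vq_num_embeds = 5
vocab_map = {i: chr(i) for i in range(99)}
vocab_map[4] = "<image>"
for i in range(vq_img_token_start_id, vq_img_token_start_id + vq_num_embeds):
    # Encode each decimal digit of the id as a letter: "10" -> "BA".
    infix = "".join(chr(ord("A") + int(c)) for c in str(i))
    vocab_map[i] = f"IMGIMG{infix}Z"
print(vocab_map[10])  # IMGIMGBAZ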
docs/transformers/tests/models/chameleon/test_processor_chameleon.py
ADDED
@@ -0,0 +1,76 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the Chameleon processor."""

import tempfile
import unittest

from transformers import ChameleonProcessor, LlamaTokenizer
from transformers.testing_utils import get_tests_dir
from transformers.utils import is_vision_available

from ...test_processing_common import ProcessorTesterMixin


if is_vision_available():
    from transformers import ChameleonImageProcessor


SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")


class ChameleonProcessorTest(ProcessorTesterMixin, unittest.TestCase):
    processor_class = ChameleonProcessor

    @classmethod
    def setUpClass(cls):
        cls.tmpdirname = tempfile.mkdtemp()
        image_processor = ChameleonImageProcessor()
        tokenizer = LlamaTokenizer(vocab_file=SAMPLE_VOCAB)
        tokenizer.pad_token_id = 0
        tokenizer.sep_token_id = 1
        tokenizer.add_special_tokens({"additional_special_tokens": ["<image>"]})
        processor = cls.processor_class(image_processor=image_processor, tokenizer=tokenizer, image_seq_length=2)
        processor.save_pretrained(cls.tmpdirname)
        cls.image_token = processor.image_token

    def test_special_mm_token_truncation(self):
        """Tests that special vision tokens do not get truncated when `truncation=True` is set."""

        processor = self.get_processor()

        input_str = self.prepare_text_inputs(batch_size=2, modality="image")
        image_input = self.prepare_image_inputs(batch_size=2)

        _ = processor(
            text=input_str,
            images=image_input,
            return_tensors="pt",
            truncation=None,
            padding=True,
        )

        with self.assertRaises(ValueError):
            _ = processor(
                text=input_str,
                images=image_input,
                return_tensors="pt",
                truncation=True,
                padding=True,
                max_length=20,
            )

    @staticmethod
    def prepare_processor_dict():
        return {"image_seq_length": 2}  # fmt: skip
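A minimal sketch (assuming a transformers source checkout, run from the repository root so the relative imports and fixture paths resolve) for running just this processor test file programmatically rather than through pytest:

import unittest

# Discover and run only the Chameleon processor tests; the pattern filter
# keeps discovery from loading every other test module under tests/.
loader = unittest.TestLoader()
suite = loader.discover("tests", pattern="test_processor_chameleon.py")
unittest.TextTestRunner(verbosity=2).run(suite)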
docs/transformers/tests/models/chinese_clip/__init__.py
ADDED
File without changes
docs/transformers/tests/models/chinese_clip/test_image_processing_chinese_clip.py
ADDED
@@ -0,0 +1,175 @@
# Copyright 2021 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import unittest

from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torchvision_available, is_vision_available

from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs


if is_vision_available():
    from transformers import ChineseCLIPImageProcessor

if is_torchvision_available():
    from transformers import ChineseCLIPImageProcessorFast


class ChineseCLIPImageProcessingTester:
    def __init__(
        self,
        parent,
        batch_size=7,
        num_channels=3,
        image_size=18,
        min_resolution=30,
        max_resolution=400,
        do_resize=True,
        size=None,
        do_center_crop=True,
        crop_size=None,
        do_normalize=True,
        image_mean=[0.48145466, 0.4578275, 0.40821073],
        image_std=[0.26862954, 0.26130258, 0.27577711],
        do_convert_rgb=True,
    ):
        size = size if size is not None else {"height": 224, "width": 224}
        crop_size = crop_size if crop_size is not None else {"height": 18, "width": 18}
        self.parent = parent
        self.batch_size = batch_size
        self.num_channels = num_channels
        self.image_size = image_size
        self.min_resolution = min_resolution
        self.max_resolution = max_resolution
        self.do_resize = do_resize
        self.size = size
        self.do_center_crop = do_center_crop
        self.crop_size = crop_size
        self.do_normalize = do_normalize
        self.image_mean = image_mean
        self.image_std = image_std
        self.do_convert_rgb = do_convert_rgb

    def prepare_image_processor_dict(self):
        return {
            "do_resize": self.do_resize,
            "size": self.size,
            "do_center_crop": self.do_center_crop,
            "crop_size": self.crop_size,
            "do_normalize": self.do_normalize,
            "image_mean": self.image_mean,
            "image_std": self.image_std,
            "do_convert_rgb": self.do_convert_rgb,
        }

    def expected_output_image_shape(self, images):
        return 3, self.crop_size["height"], self.crop_size["width"]

    def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False):
        return prepare_image_inputs(
            batch_size=self.batch_size,
            num_channels=self.num_channels,
            min_resolution=self.min_resolution,
            max_resolution=self.max_resolution,
            equal_resolution=equal_resolution,
            numpify=numpify,
            torchify=torchify,
        )


@require_torch
@require_vision
class ChineseCLIPImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
    image_processing_class = ChineseCLIPImageProcessor if is_vision_available() else None
    fast_image_processing_class = ChineseCLIPImageProcessorFast if is_torchvision_available() else None

    def setUp(self):
        super().setUp()
        self.image_processor_tester = ChineseCLIPImageProcessingTester(self, do_center_crop=True)

    @property
    def image_processor_dict(self):
        return self.image_processor_tester.prepare_image_processor_dict()

    def test_image_processor_properties(self):
        for image_processing_class in self.image_processor_list:
            image_processing = image_processing_class(**self.image_processor_dict)
            self.assertTrue(hasattr(image_processing, "do_resize"))
            self.assertTrue(hasattr(image_processing, "size"))
            self.assertTrue(hasattr(image_processing, "do_center_crop"))
            self.assertTrue(hasattr(image_processing, "center_crop"))
            self.assertTrue(hasattr(image_processing, "do_normalize"))
            self.assertTrue(hasattr(image_processing, "image_mean"))
            self.assertTrue(hasattr(image_processing, "image_std"))
            self.assertTrue(hasattr(image_processing, "do_convert_rgb"))

    def test_image_processor_from_dict_with_kwargs(self):
        for image_processing_class in self.image_processor_list:
            image_processor = image_processing_class.from_dict(self.image_processor_dict)
            self.assertEqual(image_processor.size, {"height": 224, "width": 224})
            self.assertEqual(image_processor.crop_size, {"height": 18, "width": 18})

            image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84)
            self.assertEqual(image_processor.size, {"shortest_edge": 42})
            self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84})

    @unittest.skip(
        reason="ChineseCLIPImageProcessor doesn't treat 4 channel PIL and numpy consistently yet"
    )  # FIXME Amy
    def test_call_numpy_4_channels(self):
        pass


@require_torch
@require_vision
class ChineseCLIPImageProcessingTestFourChannels(ImageProcessingTestMixin, unittest.TestCase):
    image_processing_class = ChineseCLIPImageProcessor if is_vision_available() else None
    fast_image_processing_class = ChineseCLIPImageProcessorFast if is_torchvision_available() else None

    def setUp(self):
        super().setUp()
        self.image_processor_tester = ChineseCLIPImageProcessingTester(self, num_channels=4, do_center_crop=True)
        self.expected_encoded_image_num_channels = 3

    @property
    def image_processor_dict(self):
        return self.image_processor_tester.prepare_image_processor_dict()

    def test_image_processor_properties(self):
        for image_processing_class in self.image_processor_list:
            image_processing = image_processing_class(**self.image_processor_dict)
            self.assertTrue(hasattr(image_processing, "do_resize"))
            self.assertTrue(hasattr(image_processing, "size"))
            self.assertTrue(hasattr(image_processing, "do_center_crop"))
            self.assertTrue(hasattr(image_processing, "center_crop"))
            self.assertTrue(hasattr(image_processing, "do_normalize"))
            self.assertTrue(hasattr(image_processing, "image_mean"))
            self.assertTrue(hasattr(image_processing, "image_std"))
            self.assertTrue(hasattr(image_processing, "do_convert_rgb"))

    @unittest.skip(reason="ChineseCLIPImageProcessor does not support 4 channels yet")  # FIXME Amy
    def test_call_numpy(self):
        return super().test_call_numpy()

    @unittest.skip(reason="ChineseCLIPImageProcessor does not support 4 channels yet")  # FIXME Amy
    def test_call_pytorch(self):
        return super().test_call_torch()

    @unittest.skip(
        reason="ChineseCLIPImageProcessor doesn't treat 4 channel PIL and numpy consistently yet"
    )  # FIXME Amy
    def test_call_numpy_4_channels(self):
        pass
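The override behavior asserted in test_image_processor_from_dict_with_kwargs above can be seen in isolation. The snippet below is an illustrative sketch (assuming the standard from_dict kwarg handling in transformers): an integer size kwarg is normalized into the shortest-edge dict form and an integer crop_size into a square, exactly as the assertions above expect.

from transformers import ChineseCLIPImageProcessor

# Integer kwargs passed to from_dict override the dict entries and are
# normalized into the processor's canonical dict form.
image_processor = ChineseCLIPImageProcessor.from_dict(
    {"size": {"height": 224, "width": 224}, "crop_size": {"height": 18, "width": 18}},
    size=42,
    crop_size=84,
)
print(image_processor.size)       # {'shortest_edge': 42}
print(image_processor.crop_size)  # {'height': 84, 'width': 84}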
docs/transformers/tests/models/chinese_clip/test_modeling_chinese_clip.py
ADDED
@@ -0,0 +1,762 @@
| 1 |
+
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""Testing suite for the PyTorch Chinese-CLIP model."""
|
| 15 |
+
|
| 16 |
+
import inspect
|
| 17 |
+
import os
|
| 18 |
+
import tempfile
|
| 19 |
+
import unittest
|
| 20 |
+
|
| 21 |
+
import numpy as np
|
| 22 |
+
import requests
|
| 23 |
+
|
| 24 |
+
from transformers import ChineseCLIPConfig, ChineseCLIPTextConfig, ChineseCLIPVisionConfig
|
| 25 |
+
from transformers.models.auto import get_values
|
| 26 |
+
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
|
| 27 |
+
from transformers.utils import is_torch_available, is_vision_available
|
| 28 |
+
|
| 29 |
+
from ...test_configuration_common import ConfigTester
|
| 30 |
+
from ...test_modeling_common import (
|
| 31 |
+
ModelTesterMixin,
|
| 32 |
+
_config_zero_init,
|
| 33 |
+
floats_tensor,
|
| 34 |
+
ids_tensor,
|
| 35 |
+
random_attention_mask,
|
| 36 |
+
)
|
| 37 |
+
from ...test_pipeline_mixin import PipelineTesterMixin
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
if is_torch_available():
|
| 41 |
+
import torch
|
| 42 |
+
from torch import nn
|
| 43 |
+
|
| 44 |
+
from transformers import (
|
| 45 |
+
MODEL_FOR_PRETRAINING_MAPPING,
|
| 46 |
+
ChineseCLIPModel,
|
| 47 |
+
ChineseCLIPTextModel,
|
| 48 |
+
ChineseCLIPVisionModel,
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
if is_vision_available():
|
| 53 |
+
from PIL import Image
|
| 54 |
+
|
| 55 |
+
from transformers import ChineseCLIPProcessor
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class ChineseCLIPTextModelTester:
|
| 59 |
+
def __init__(
|
| 60 |
+
self,
|
| 61 |
+
parent,
|
| 62 |
+
batch_size=13,
|
| 63 |
+
seq_length=7,
|
| 64 |
+
is_training=True,
|
| 65 |
+
use_input_mask=True,
|
| 66 |
+
use_token_type_ids=True,
|
| 67 |
+
use_labels=True,
|
| 68 |
+
vocab_size=99,
|
| 69 |
+
hidden_size=32,
|
| 70 |
+
num_hidden_layers=2,
|
| 71 |
+
num_attention_heads=4,
|
| 72 |
+
intermediate_size=37,
|
| 73 |
+
hidden_act="gelu",
|
| 74 |
+
hidden_dropout_prob=0.1,
|
| 75 |
+
attention_probs_dropout_prob=0.1,
|
| 76 |
+
max_position_embeddings=512,
|
| 77 |
+
type_vocab_size=16,
|
| 78 |
+
type_sequence_label_size=2,
|
| 79 |
+
initializer_range=0.02,
|
| 80 |
+
num_labels=3,
|
| 81 |
+
num_choices=4,
|
| 82 |
+
scope=None,
|
| 83 |
+
):
|
| 84 |
+
self.parent = parent
|
| 85 |
+
self.batch_size = batch_size
|
| 86 |
+
self.seq_length = seq_length
|
| 87 |
+
self.is_training = is_training
|
| 88 |
+
self.use_input_mask = use_input_mask
|
| 89 |
+
self.use_token_type_ids = use_token_type_ids
|
| 90 |
+
self.use_labels = use_labels
|
| 91 |
+
self.vocab_size = vocab_size
|
| 92 |
+
self.hidden_size = hidden_size
|
| 93 |
+
self.num_hidden_layers = num_hidden_layers
|
| 94 |
+
self.num_attention_heads = num_attention_heads
|
| 95 |
+
self.intermediate_size = intermediate_size
|
| 96 |
+
self.hidden_act = hidden_act
|
| 97 |
+
self.hidden_dropout_prob = hidden_dropout_prob
|
| 98 |
+
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
| 99 |
+
self.max_position_embeddings = max_position_embeddings
|
| 100 |
+
self.type_vocab_size = type_vocab_size
|
| 101 |
+
self.type_sequence_label_size = type_sequence_label_size
|
| 102 |
+
self.initializer_range = initializer_range
|
| 103 |
+
self.num_labels = num_labels
|
| 104 |
+
self.num_choices = num_choices
|
| 105 |
+
self.scope = scope
|
| 106 |
+
|
| 107 |
+
def prepare_config_and_inputs(self):
|
| 108 |
+
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
| 109 |
+
|
| 110 |
+
input_mask = None
|
| 111 |
+
if self.use_input_mask:
|
| 112 |
+
input_mask = random_attention_mask([self.batch_size, self.seq_length])
|
| 113 |
+
|
| 114 |
+
token_type_ids = None
|
| 115 |
+
if self.use_token_type_ids:
|
| 116 |
+
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
|
| 117 |
+
|
| 118 |
+
sequence_labels = None
|
| 119 |
+
token_labels = None
|
| 120 |
+
        choice_labels = None
        if self.use_labels:
            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
            choice_labels = ids_tensor([self.batch_size], self.num_choices)

        config = self.get_config()

        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels

    def get_config(self):
        """
        Returns a tiny configuration by default.
        """
        return ChineseCLIPTextConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            hidden_act=self.hidden_act,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            type_vocab_size=self.type_vocab_size,
            is_decoder=False,
            initializer_range=self.initializer_range,
        )

    def prepare_config_and_inputs_for_decoder(self):
        (
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
        ) = self.prepare_config_and_inputs()

        config.is_decoder = True
        encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
        encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)

        return (
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
            encoder_hidden_states,
            encoder_attention_mask,
        )

    def create_and_check_model(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        model = ChineseCLIPTextModel(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
        result = model(input_ids, token_type_ids=token_type_ids)
        result = model(input_ids)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
        self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))

    def create_and_check_model_as_decoder(
        self,
        config,
        input_ids,
        token_type_ids,
        input_mask,
        sequence_labels,
        token_labels,
        choice_labels,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        config.add_cross_attention = True
        model = ChineseCLIPTextModel(config)
        model.to(torch_device)
        model.eval()
        result = model(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
        )
        result = model(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            encoder_hidden_states=encoder_hidden_states,
        )
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
        self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        (
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
        ) = config_and_inputs
        inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
        return config, inputs_dict
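

# A minimal usage sketch for the tester above, assuming a hypothetical `parent`
# test case (the variable names here are illustrative, not part of the suite):
#
#     tester = ChineseCLIPTextModelTester(parent)
#     config, inputs_dict = tester.prepare_config_and_inputs_for_common()
#     model = ChineseCLIPTextModel(config).to(torch_device).eval()
#     outputs = model(**inputs_dict)
#     # outputs.last_hidden_state: (batch_size, seq_length, hidden_size),
#     # matching the assertions in create_and_check_model above.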


class ChineseCLIPVisionModelTester:
    def __init__(
        self,
        parent,
        batch_size=12,
        image_size=30,
        patch_size=2,
        num_channels=3,
        is_training=True,
        hidden_size=32,
        projection_dim=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=37,
        dropout=0.1,
        attention_dropout=0.1,
        initializer_range=0.02,
        scope=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.is_training = is_training
        self.hidden_size = hidden_size
        self.projection_dim = projection_dim
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.initializer_range = initializer_range
        self.scope = scope

        # in ViT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
        num_patches = (image_size // patch_size) ** 2
        self.seq_length = num_patches + 1
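        # With the defaults above (image_size=30, patch_size=2), this gives
        # (30 // 2) ** 2 = 225 patches, so seq_length = 226.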

    def prepare_config_and_inputs(self):
        pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
        config = self.get_config()

        return config, pixel_values

    def get_config(self):
        return ChineseCLIPVisionConfig(
            image_size=self.image_size,
            patch_size=self.patch_size,
            num_channels=self.num_channels,
            hidden_size=self.hidden_size,
            projection_dim=self.projection_dim,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            dropout=self.dropout,
            attention_dropout=self.attention_dropout,
            initializer_range=self.initializer_range,
        )

    def create_and_check_model(self, config, pixel_values):
        model = ChineseCLIPVisionModel(config=config)
        model.to(torch_device)
        model.eval()
        with torch.no_grad():
            result = model(pixel_values)
        # expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
        image_size = (self.image_size, self.image_size)
        patch_size = (self.patch_size, self.patch_size)
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
        self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, pixel_values = config_and_inputs
        inputs_dict = {"pixel_values": pixel_values}
        return config, inputs_dict


@require_torch
class ChineseCLIPTextModelTest(ModelTesterMixin, unittest.TestCase):
    all_model_classes = (ChineseCLIPTextModel,) if is_torch_available() else ()
    fx_compatible = False

    # special case for ForPreTraining model
    def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
        inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)

        if return_labels:
            if model_class in get_values(MODEL_FOR_PRETRAINING_MAPPING):
                inputs_dict["labels"] = torch.zeros(
                    (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
                )
                inputs_dict["next_sentence_label"] = torch.zeros(
                    self.model_tester.batch_size, dtype=torch.long, device=torch_device
                )
        return inputs_dict

    def setUp(self):
        self.model_tester = ChineseCLIPTextModelTester(self)
        self.config_tester = ConfigTester(self, config_class=ChineseCLIPTextConfig, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_model_various_embeddings(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        for type in ["absolute", "relative_key", "relative_key_query"]:
            config_and_inputs[0].position_embedding_type = type
            self.model_tester.create_and_check_model(*config_and_inputs)

    def test_model_as_decoder(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
        self.model_tester.create_and_check_model_as_decoder(*config_and_inputs)

    def test_model_as_decoder_with_default_input_mask(self):
        (
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
            encoder_hidden_states,
            encoder_attention_mask,
        ) = self.model_tester.prepare_config_and_inputs_for_decoder()

        input_mask = None

        self.model_tester.create_and_check_model_as_decoder(
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
            encoder_hidden_states,
            encoder_attention_mask,
        )

    @slow
    def test_model_from_pretrained(self):
        model_name = "OFA-Sys/chinese-clip-vit-base-patch16"
        model = ChineseCLIPTextModel.from_pretrained(model_name)
        self.assertIsNotNone(model)

    @unittest.skip
    def test_training(self):
        pass

    @unittest.skip
    def test_training_gradient_checkpointing(self):
        pass

    @unittest.skip(
        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant(self):
        pass

    @unittest.skip(
        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass


@require_torch
class ChineseCLIPVisionModelTest(ModelTesterMixin, unittest.TestCase):
    """
    Here we also overwrite some of the tests of test_modeling_common.py, as CHINESE_CLIP does not use input_ids,
    inputs_embeds, attention_mask and seq_length.
    """

    all_model_classes = (ChineseCLIPVisionModel,) if is_torch_available() else ()
    fx_compatible = False
    test_pruning = False
    test_resize_embeddings = False
    test_head_masking = False

    def setUp(self):
        self.model_tester = ChineseCLIPVisionModelTester(self)
        self.config_tester = ConfigTester(
            self, config_class=ChineseCLIPVisionConfig, has_text_modality=False, hidden_size=37
        )

    def test_config(self):
        self.config_tester.run_common_tests()

    @unittest.skip(reason="CHINESE_CLIP does not use inputs_embeds")
    def test_inputs_embeds(self):
        pass

    def test_model_get_set_embeddings(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            self.assertIsInstance(model.get_input_embeddings(), nn.Module)
            x = model.get_output_embeddings()
            self.assertTrue(x is None or isinstance(x, nn.Linear))

    def test_forward_signature(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.forward)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            expected_arg_names = ["pixel_values"]
            self.assertListEqual(arg_names[:1], expected_arg_names)

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    @unittest.skip
    def test_training(self):
        pass

    @unittest.skip
    def test_training_gradient_checkpointing(self):
        pass

    @unittest.skip(
        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant(self):
        pass

    @unittest.skip(
        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass

    @slow
    def test_model_from_pretrained(self):
        model_name = "OFA-Sys/chinese-clip-vit-base-patch16"
        model = ChineseCLIPVisionModel.from_pretrained(model_name)
        self.assertIsNotNone(model)


class ChineseCLIPModelTester:
    def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=True):
        if text_kwargs is None:
            text_kwargs = {}
        if vision_kwargs is None:
            vision_kwargs = {}

        self.parent = parent
        self.text_model_tester = ChineseCLIPTextModelTester(parent, **text_kwargs)
        self.vision_model_tester = ChineseCLIPVisionModelTester(parent, **vision_kwargs)
        self.batch_size = self.text_model_tester.batch_size  # need bs for batching_equivalence test
        self.is_training = is_training

    def prepare_config_and_inputs(self):
        (
            config,
            input_ids,
            token_type_ids,
            attention_mask,
            _,
            __,
            ___,
        ) = self.text_model_tester.prepare_config_and_inputs()
        vision_config, pixel_values = self.vision_model_tester.prepare_config_and_inputs()

        config = self.get_config()

        return config, input_ids, token_type_ids, attention_mask, pixel_values

    def get_config(self):
        return ChineseCLIPConfig.from_text_vision_configs(
            self.text_model_tester.get_config(), self.vision_model_tester.get_config(), projection_dim=64
        )

    def create_and_check_model(self, config, input_ids, token_type_ids, attention_mask, pixel_values):
        model = ChineseCLIPModel(config).to(torch_device).eval()
        with torch.no_grad():
            result = model(input_ids, pixel_values, attention_mask, token_type_ids)
        self.parent.assertEqual(
            result.logits_per_image.shape, (self.vision_model_tester.batch_size, self.text_model_tester.batch_size)
        )
        self.parent.assertEqual(
            result.logits_per_text.shape, (self.text_model_tester.batch_size, self.vision_model_tester.batch_size)
        )

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, input_ids, token_type_ids, attention_mask, pixel_values = config_and_inputs
        inputs_dict = {
            "input_ids": input_ids,
            "token_type_ids": token_type_ids,
            "attention_mask": attention_mask,
            "pixel_values": pixel_values,
            "return_loss": True,
        }
        return config, inputs_dict
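

# A sketch of the relationship the shape assertions above encode, assuming the
# usual CLIP formulation (illustrative variable names, not the model's internals):
#
#     logits_per_image = logit_scale.exp() * image_embeds @ text_embeds.T
#     # shape: (num_images, num_texts)
#     logits_per_text = logits_per_image.t()
#     # shape: (num_texts, num_images)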


@require_torch
class ChineseCLIPModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    all_model_classes = (ChineseCLIPModel,) if is_torch_available() else ()
    pipeline_model_mapping = {"feature-extraction": ChineseCLIPModel} if is_torch_available() else {}
    fx_compatible = False
    test_head_masking = False
    test_pruning = False
    test_resize_embeddings = False
    test_attention_outputs = False

    def setUp(self):
        text_kwargs = {"use_labels": False, "batch_size": 12}
        vision_kwargs = {"batch_size": 12}
        self.model_tester = ChineseCLIPModelTester(self, text_kwargs, vision_kwargs)

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    @unittest.skip(reason="Hidden_states is tested in individual model tests")
    def test_hidden_states_output(self):
        pass

    @unittest.skip(reason="Inputs_embeds is tested in individual model tests")
    def test_inputs_embeds(self):
        pass

    @unittest.skip(reason="Retain_grad is tested in individual model tests")
    def test_retain_grad_hidden_states_attentions(self):
        pass

    @unittest.skip(reason="ChineseCLIPModel does not have input/output embeddings")
    def test_model_get_set_embeddings(self):
        pass

    # override as the `logit_scale` parameter initialization is different for CHINESE_CLIP
    def test_initialization(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        configs_no_init = _config_zero_init(config)
        for sub_config_key in ("vision_config", "text_config"):
            sub_config = getattr(configs_no_init, sub_config_key, {})
            setattr(configs_no_init, sub_config_key, _config_zero_init(sub_config))
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            for name, param in model.named_parameters():
                if param.requires_grad:
                    # check if `logit_scale` is initialized as per the original implementation
                    if name == "logit_scale":
                        self.assertAlmostEqual(
                            param.data.item(),
                            np.log(1 / 0.07),
                            delta=1e-3,
                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                        )
                    else:
                        self.assertIn(
                            ((param.data.mean() * 1e9).round() / 1e9).item(),
                            [0.0, 1.0],
                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                        )
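
    # For reference: np.log(1 / 0.07) ≈ 2.6593, the temperature initialization
    # used by the original CLIP implementation, which the check above mirrors.
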
    def _create_and_check_torchscript(self, config, inputs_dict):
        if not self.test_torchscript:
            self.skipTest(reason="test_torchscript is set to False")

        configs_no_init = _config_zero_init(config)  # To be sure we have no Nan
        configs_no_init.torchscript = True
        configs_no_init.return_dict = False
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            model.to(torch_device)
            model.eval()

            try:
                input_ids = inputs_dict["input_ids"]
                pixel_values = inputs_dict["pixel_values"]  # CHINESE_CLIP needs pixel_values
                traced_model = torch.jit.trace(model, (input_ids, pixel_values))
            except RuntimeError:
                self.fail("Couldn't trace module.")

            with tempfile.TemporaryDirectory() as tmp_dir_name:
                pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")

                try:
                    torch.jit.save(traced_model, pt_file_name)
                except Exception:
                    self.fail("Couldn't save module.")

                try:
                    loaded_model = torch.jit.load(pt_file_name)
                except Exception:
                    self.fail("Couldn't load module.")

            model.to(torch_device)
            model.eval()

            loaded_model.to(torch_device)
            loaded_model.eval()

            model_state_dict = model.state_dict()
            loaded_model_state_dict = loaded_model.state_dict()

            non_persistent_buffers = {}
            for key in loaded_model_state_dict.keys():
                if key not in model_state_dict.keys():
                    non_persistent_buffers[key] = loaded_model_state_dict[key]

            loaded_model_state_dict = {
                key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers
            }

            self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))

            model_buffers = list(model.buffers())
            for non_persistent_buffer in non_persistent_buffers.values():
                found_buffer = False
                for i, model_buffer in enumerate(model_buffers):
                    if torch.equal(non_persistent_buffer, model_buffer):
                        found_buffer = True
                        break

                self.assertTrue(found_buffer)
                model_buffers.pop(i)

            models_equal = True
            for layer_name, p1 in model_state_dict.items():
                p2 = loaded_model_state_dict[layer_name]
                if p1.data.ne(p2.data).sum() > 0:
                    models_equal = False

            self.assertTrue(models_equal)

    @slow
    def test_model_from_pretrained(self):
        model_name = "OFA-Sys/chinese-clip-vit-base-patch16"
        model = ChineseCLIPModel.from_pretrained(model_name)
        self.assertIsNotNone(model)


# We will verify our results on an image of Pikachu
def prepare_img():
    url = "https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/pokemon.jpeg"
    im = Image.open(requests.get(url, stream=True).raw)
    return im
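

# An equivalent way to load the image, shown only as a sketch in case the raw
# stream handle is inconvenient (assumes the standard-library `io` module):
#
#     import io
#     im = Image.open(io.BytesIO(requests.get(url).content))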


@require_vision
@require_torch
class ChineseCLIPModelIntegrationTest(unittest.TestCase):
    @slow
    def test_inference(self):
        model_name = "OFA-Sys/chinese-clip-vit-base-patch16"
        model = ChineseCLIPModel.from_pretrained(model_name).to(torch_device)
        processor = ChineseCLIPProcessor.from_pretrained(model_name)

        image = prepare_img()
        inputs = processor(
            text=["杰尼龟", "妙蛙种子", "小火龙", "皮卡丘"], images=image, padding=True, return_tensors="pt"
        ).to(torch_device)

        # forward pass
        with torch.no_grad():
            outputs = model(**inputs)

        # verify the logits
        self.assertEqual(
            outputs.logits_per_image.shape,
            torch.Size((inputs.pixel_values.shape[0], inputs.input_ids.shape[0])),
        )
        self.assertEqual(
            outputs.logits_per_text.shape,
            torch.Size((inputs.input_ids.shape[0], inputs.pixel_values.shape[0])),
        )

        probs = outputs.logits_per_image.softmax(dim=1)
        expected_probs = torch.tensor([[1.2686e-03, 5.4499e-02, 6.7968e-04, 9.4355e-01]], device=torch_device)

        torch.testing.assert_close(probs, expected_probs, rtol=5e-3, atol=5e-3)
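        # The four candidate captions are the Chinese names of Squirtle, Bulbasaur,
        # Charmander, and Pikachu; ~94% of the probability mass lands on "皮卡丘"
        # (Pikachu), matching the image fetched by prepare_img().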

    @slow
    def test_inference_interpolate_pos_encoding(self):
        # ViT models have an `interpolate_pos_encoding` argument in their forward method,
        # allowing to interpolate the pre-trained position embeddings in order to use
        # the model on higher resolutions. The DINO model by Facebook AI leverages this
        # to visualize self-attention on higher resolution images.
        model_name = "OFA-Sys/chinese-clip-vit-base-patch16"
        model = ChineseCLIPModel.from_pretrained(model_name).to(torch_device)

        image_processor = ChineseCLIPProcessor.from_pretrained(
            model_name, size={"height": 180, "width": 180}, crop_size={"height": 180, "width": 180}
        )

        image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
        inputs = image_processor(text="what's in the image", images=image, return_tensors="pt").to(torch_device)

        # interpolate_pos_encoding=False should raise a ValueError
        with self.assertRaises(ValueError, msg="doesn't match model"):
            with torch.no_grad():
                model(**inputs, interpolate_pos_encoding=False)

        # forward pass
        with torch.no_grad():
            outputs = model(**inputs, interpolate_pos_encoding=True)

        # verify the logits
        expected_shape = torch.Size((1, 122, 768))

        self.assertEqual(outputs.vision_model_output.last_hidden_state.shape, expected_shape)

        expected_slice = torch.tensor(
            [[-0.3990, 0.2983, -0.1239], [-0.1452, -0.2759, 0.0403], [-0.3149, -0.4763, 0.8555]]
        ).to(torch_device)

        torch.testing.assert_close(
            outputs.vision_model_output.last_hidden_state[0, :3, :3], expected_slice, rtol=1e-4, atol=1e-4
        )
docs/transformers/tests/models/chinese_clip/test_processor_chinese_clip.py
ADDED
@@ -0,0 +1,217 @@
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import shutil
import tempfile
import unittest

import pytest

from transformers import BertTokenizer, BertTokenizerFast
from transformers.models.bert.tokenization_bert import VOCAB_FILES_NAMES
from transformers.testing_utils import require_vision
from transformers.utils import FEATURE_EXTRACTOR_NAME, is_vision_available

from ...test_processing_common import ProcessorTesterMixin


if is_vision_available():
    from transformers import ChineseCLIPImageProcessor, ChineseCLIPProcessor


@require_vision
class ChineseCLIPProcessorTest(ProcessorTesterMixin, unittest.TestCase):
    processor_class = ChineseCLIPProcessor

    @classmethod
    def setUpClass(cls):
        cls.tmpdirname = tempfile.mkdtemp()

        vocab_tokens = [
            "[UNK]",
            "[CLS]",
            "[SEP]",
            "[PAD]",
            "[MASK]",
            "的",
            "价",
            "格",
            "是",
            "15",
            "便",
            "alex",
            "##andra",
            ",",
            "。",
            "-",
            "t",
            "shirt",
        ]
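        # Entries prefixed with "##" are WordPiece continuation pieces: with this
        # vocab, "alexandra" tokenizes to ["alex", "##andra"].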
        cls.vocab_file = os.path.join(cls.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
        with open(cls.vocab_file, "w", encoding="utf-8") as vocab_writer:
            vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))

        image_processor_map = {
            "do_resize": True,
            "size": {"height": 224, "width": 224},
            "do_center_crop": True,
            "crop_size": {"height": 18, "width": 18},
            "do_normalize": True,
            "image_mean": [0.48145466, 0.4578275, 0.40821073],
            "image_std": [0.26862954, 0.26130258, 0.27577711],
            "do_convert_rgb": True,
        }
        cls.image_processor_file = os.path.join(cls.tmpdirname, FEATURE_EXTRACTOR_NAME)
        with open(cls.image_processor_file, "w", encoding="utf-8") as fp:
            json.dump(image_processor_map, fp)

        tokenizer = cls.get_tokenizer()
        image_processor = cls.get_image_processor()
        processor = ChineseCLIPProcessor(tokenizer=tokenizer, image_processor=image_processor)
        processor.save_pretrained(cls.tmpdirname)

    @classmethod
    def get_tokenizer(cls, **kwargs):
        return BertTokenizer.from_pretrained(cls.tmpdirname, **kwargs)

    @classmethod
    def get_rust_tokenizer(cls, **kwargs):
        return BertTokenizerFast.from_pretrained(cls.tmpdirname, **kwargs)

    @classmethod
    def get_image_processor(cls, **kwargs):
        return ChineseCLIPImageProcessor.from_pretrained(cls.tmpdirname, **kwargs)

    @classmethod
    def tearDownClass(cls):
        shutil.rmtree(cls.tmpdirname, ignore_errors=True)

    def test_save_load_pretrained_default(self):
        tokenizer_slow = self.get_tokenizer()
        tokenizer_fast = self.get_rust_tokenizer()
        image_processor = self.get_image_processor()

        with tempfile.TemporaryDirectory() as tmpdir:
            processor_slow = ChineseCLIPProcessor(tokenizer=tokenizer_slow, image_processor=image_processor)
            processor_slow.save_pretrained(tmpdir)
            # reload from the directory we just saved to, not the class-level one
            processor_slow = ChineseCLIPProcessor.from_pretrained(tmpdir, use_fast=False)

            processor_fast = ChineseCLIPProcessor(tokenizer=tokenizer_fast, image_processor=image_processor)
            processor_fast.save_pretrained(tmpdir)
            processor_fast = ChineseCLIPProcessor.from_pretrained(tmpdir)

        self.assertEqual(processor_slow.tokenizer.get_vocab(), tokenizer_slow.get_vocab())
        self.assertEqual(processor_fast.tokenizer.get_vocab(), tokenizer_fast.get_vocab())
        self.assertEqual(tokenizer_slow.get_vocab(), tokenizer_fast.get_vocab())
        self.assertIsInstance(processor_slow.tokenizer, BertTokenizer)
        self.assertIsInstance(processor_fast.tokenizer, BertTokenizerFast)

        self.assertEqual(processor_slow.image_processor.to_json_string(), image_processor.to_json_string())
        self.assertEqual(processor_fast.image_processor.to_json_string(), image_processor.to_json_string())
        self.assertIsInstance(processor_slow.image_processor, ChineseCLIPImageProcessor)
        self.assertIsInstance(processor_fast.image_processor, ChineseCLIPImageProcessor)

    def test_save_load_pretrained_additional_features(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            processor = ChineseCLIPProcessor(
                tokenizer=self.get_tokenizer(), image_processor=self.get_image_processor()
            )
            processor.save_pretrained(tmpdir)

            tokenizer_add_kwargs = self.get_tokenizer(cls_token="(CLS)", sep_token="(SEP)")
            image_processor_add_kwargs = self.get_image_processor(do_normalize=False)

            processor = ChineseCLIPProcessor.from_pretrained(
                tmpdir, cls_token="(CLS)", sep_token="(SEP)", do_normalize=False
            )

        self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab())
        self.assertIsInstance(processor.tokenizer, BertTokenizerFast)

        self.assertEqual(processor.image_processor.to_json_string(), image_processor_add_kwargs.to_json_string())
        self.assertIsInstance(processor.image_processor, ChineseCLIPImageProcessor)

    def test_image_processor(self):
        image_processor = self.get_image_processor()
        tokenizer = self.get_tokenizer()

        processor = ChineseCLIPProcessor(tokenizer=tokenizer, image_processor=image_processor)

        image_input = self.prepare_image_inputs()

        input_feat_extract = image_processor(image_input, return_tensors="np")
        input_processor = processor(images=image_input, return_tensors="np")

        for key in input_feat_extract.keys():
            self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2)

    def test_tokenizer(self):
        image_processor = self.get_image_processor()
        tokenizer = self.get_tokenizer()

        processor = ChineseCLIPProcessor(tokenizer=tokenizer, image_processor=image_processor)

        input_str = "Alexandra,T-shirt的价格是15便士。"

        encoded_processor = processor(text=input_str)

        encoded_tok = tokenizer(input_str)

        for key in encoded_tok.keys():
            self.assertListEqual(encoded_tok[key], encoded_processor[key])

    def test_processor(self):
        image_processor = self.get_image_processor()
        tokenizer = self.get_tokenizer()

        processor = ChineseCLIPProcessor(tokenizer=tokenizer, image_processor=image_processor)

        input_str = "Alexandra,T-shirt的价格是15便士。"
        image_input = self.prepare_image_inputs()

        inputs = processor(text=input_str, images=image_input)

        self.assertListEqual(list(inputs.keys()), ["input_ids", "token_type_ids", "attention_mask", "pixel_values"])

        # test if it raises when no input is passed
        with pytest.raises(ValueError):
            processor()

    def test_tokenizer_decode(self):
        image_processor = self.get_image_processor()
        tokenizer = self.get_tokenizer()

        processor = ChineseCLIPProcessor(tokenizer=tokenizer, image_processor=image_processor)

        predicted_ids = [[1, 4, 5, 8, 1, 0, 8], [3, 4, 3, 1, 1, 8, 9]]

        decoded_processor = processor.batch_decode(predicted_ids)
        decoded_tok = tokenizer.batch_decode(predicted_ids)

        self.assertListEqual(decoded_tok, decoded_processor)

    def test_model_input_names(self):
        image_processor = self.get_image_processor()
        tokenizer = self.get_tokenizer()

        processor = ChineseCLIPProcessor(tokenizer=tokenizer, image_processor=image_processor)

        input_str = "Alexandra,T-shirt的价格是15便士。"
        image_input = self.prepare_image_inputs()

        inputs = processor(text=input_str, images=image_input)

        self.assertListEqual(list(inputs.keys()), processor.model_input_names)
docs/transformers/tests/models/clap/__init__.py
ADDED
File without changes
docs/transformers/tests/models/clap/test_feature_extraction_clap.py
ADDED
@@ -0,0 +1,546 @@
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import itertools
import random
import unittest

import numpy as np
from datasets import load_dataset

from transformers import ClapFeatureExtractor
from transformers.testing_utils import require_torch, require_torchaudio
from transformers.trainer_utils import set_seed
from transformers.utils.import_utils import is_torch_available

from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin


if is_torch_available():
    import torch

global_rng = random.Random()


# Copied from tests.models.whisper.test_feature_extraction_whisper.floats_list
def floats_list(shape, scale=1.0, rng=None, name=None):
    """Creates a random float32 tensor"""
    if rng is None:
        rng = global_rng

    values = []
    for batch_idx in range(shape[0]):
        values.append([])
        for _ in range(shape[1]):
            values[-1].append(rng.random() * scale)

    return values
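

# Usage sketch: floats_list((2, 3)) returns a 2x3 nested Python list (not an
# actual tensor, despite the docstring) of floats drawn uniformly from [0.0, scale).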


@require_torch
@require_torchaudio
# Copied from tests.models.whisper.test_feature_extraction_whisper.WhisperFeatureExtractionTester with Whisper->Clap
class ClapFeatureExtractionTester:
    def __init__(
        self,
        parent,
        batch_size=7,
        min_seq_length=400,
        max_seq_length=2000,
        feature_size=10,
        hop_length=160,
        chunk_length=8,
        padding_value=0.0,
        sampling_rate=4_000,
        return_attention_mask=False,
        do_normalize=True,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.min_seq_length = min_seq_length
        self.max_seq_length = max_seq_length
        self.seq_length_diff = (self.max_seq_length - self.min_seq_length) // (self.batch_size - 1)
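        # With the defaults (min 400, max 2000, batch_size 7), the step is
        # (2000 - 400) // 6 = 266 samples between consecutive input lengths.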
        self.padding_value = padding_value
        self.sampling_rate = sampling_rate
        self.return_attention_mask = return_attention_mask
        self.do_normalize = do_normalize
        self.feature_size = feature_size
        self.chunk_length = chunk_length
        self.hop_length = hop_length

    def prepare_feat_extract_dict(self):
        return {
            "feature_size": self.feature_size,
            "hop_length": self.hop_length,
            "chunk_length": self.chunk_length,
            "padding_value": self.padding_value,
            "sampling_rate": self.sampling_rate,
            "return_attention_mask": self.return_attention_mask,
            "do_normalize": self.do_normalize,
        }

    def prepare_inputs_for_common(self, equal_length=False, numpify=False):
        def _flatten(list_of_lists):
            return list(itertools.chain(*list_of_lists))

        if equal_length:
            speech_inputs = [floats_list((self.max_seq_length, self.feature_size)) for _ in range(self.batch_size)]
        else:
            # make sure that inputs increase in size
            speech_inputs = [
                floats_list((x, self.feature_size))
                for x in range(self.min_seq_length, self.max_seq_length, self.seq_length_diff)
            ]
        if numpify:
            speech_inputs = [np.asarray(x) for x in speech_inputs]
        return speech_inputs


@require_torch
@require_torchaudio
class ClapFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
    feature_extraction_class = ClapFeatureExtractor

    # Copied from tests.models.whisper.test_feature_extraction_whisper.WhisperFeatureExtractionTest.setUp with Whisper->Clap
    def setUp(self):
        self.feat_extract_tester = ClapFeatureExtractionTester(self)

    def test_call(self):
        # Tests that all calls wrap to encode_plus and batch_encode_plus
        feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
        # create three inputs of length 800, 1000, and 1200
        speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
        np_speech_inputs = [np.asarray(speech_input) for speech_input in speech_inputs]

        # Test feature size
        input_features = feature_extractor(np_speech_inputs, padding="max_length", return_tensors="np").input_features
        self.assertTrue(input_features.ndim == 4)

        # Test not batched input
        encoded_sequences_1 = feature_extractor(speech_inputs[0], return_tensors="np").input_features
        encoded_sequences_2 = feature_extractor(np_speech_inputs[0], return_tensors="np").input_features
        self.assertTrue(np.allclose(encoded_sequences_1, encoded_sequences_2, atol=1e-3))

        # Test batched
        encoded_sequences_1 = feature_extractor(speech_inputs, return_tensors="np").input_features
        encoded_sequences_2 = feature_extractor(np_speech_inputs, return_tensors="np").input_features
        for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
            self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))

        # Test 2-D numpy arrays are batched.
        speech_inputs = [floats_list((1, x))[0] for x in (800, 800, 800)]
        np_speech_inputs = np.asarray(speech_inputs)
        encoded_sequences_1 = feature_extractor(speech_inputs, return_tensors="np").input_features
        encoded_sequences_2 = feature_extractor(np_speech_inputs, return_tensors="np").input_features
        for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
            self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))

    # Copied from tests.models.whisper.test_feature_extraction_whisper.WhisperFeatureExtractionTest.test_double_precision_pad
    def test_double_precision_pad(self):
        import torch

        feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
        np_speech_inputs = np.random.rand(100, 32).astype(np.float64)
        py_speech_inputs = np_speech_inputs.tolist()

        for inputs in [py_speech_inputs, np_speech_inputs]:
            np_processed = feature_extractor.pad([{"input_features": inputs}], return_tensors="np")
            self.assertTrue(np_processed.input_features.dtype == np.float32)
            pt_processed = feature_extractor.pad([{"input_features": inputs}], return_tensors="pt")
            self.assertTrue(pt_processed.input_features.dtype == torch.float32)

    # Copied from tests.models.whisper.test_feature_extraction_whisper.WhisperFeatureExtractionTest._load_datasamples
    def _load_datasamples(self, num_samples):
        ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        # automatic decoding with librispeech
        speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]

        return [x["array"] for x in speech_samples]
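
    # The integration tests below push a single LibriSpeech sample through the
    # default ClapFeatureExtractor; per the shape assertion that follows, each
    # call yields input_features of shape (1, 4, 1001, 64) for every padding
    # mode being probed.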

    def test_integration_fusion_short_input(self):
        # fmt: off
        EXPECTED_INPUT_FEATURES = torch.tensor(
            [
                [
                    # "repeat"
                    [
                        -20.1049, -19.9764, -20.0731, -19.5055, -27.5018, -22.5761, -26.6071,
                        -29.0091, -26.4659, -26.4236, -28.8808, -31.9190, -32.4848, -34.1186,
                        -34.0340, -32.8803, -30.9895, -37.6238, -38.0347, -40.6263, -36.3496,
                        -42.2533, -32.9132, -27.7068, -29.3704, -30.3208, -22.5972, -27.1494,
                        -30.1975, -31.1005, -29.9372, -27.1917, -25.9806, -30.3489, -33.2380,
                        -31.9062, -36.5498, -32.8721, -30.5629, -27.4674, -22.2232, -22.5653,
                        -16.3868, -17.2713, -25.9738, -30.6256, -34.3766, -31.1292, -27.8950,
                        -27.0588, -25.6206, -23.0712, -26.6050, -28.0112, -32.6847, -34.3396,
                        -34.9738, -35.8463, -39.2324, -37.1188, -33.3705, -28.9230, -28.9112,
                        -28.6578
                    ],
                    [
                        -36.7233, -30.0587, -24.8431, -18.4611, -16.8149, -23.9319, -32.8580,
                        -34.2264, -27.4332, -26.8027, -29.2721, -33.9033, -39.3403, -35.3232,
                        -26.8076, -28.6460, -35.2780, -36.0738, -35.4996, -37.7631, -39.5056,
                        -34.7112, -36.8741, -34.1066, -32.9474, -33.6604, -27.9937, -30.9594,
                        -26.2928, -32.0485, -29.2151, -29.2917, -32.7308, -29.6542, -31.1454,
                        -37.0088, -32.3388, -37.3086, -31.1024, -27.2889, -19.6788, -21.1488,
                        -19.5144, -14.8889, -21.2006, -24.7488, -27.7940, -31.1058, -27.5068,
                        -21.5737, -22.3780, -21.5151, -26.3086, -30.9223, -33.5043, -32.0307,
                        -37.3806, -41.6188, -45.6650, -40.5131, -32.5023, -26.7385, -26.3709,
                        -26.7761
                    ]
                ],
                [
                    # "repeatpad"
                    [
                        -25.7496, -24.9339, -24.1357, -23.1271, -23.7853, -26.1264, -29.1456,
                        -33.2060, -37.8179, -42.4833, -41.9386, -41.2164, -42.3566, -44.2575,
                        -40.0217, -36.6794, -36.6974, -38.7819, -42.0880, -45.5560, -39.9368,
                        -36.3219, -35.5981, -36.6434, -35.1851, -33.0684, -30.0437, -30.2010,
                        -34.3476, -42.1373, -38.8039, -37.3355, -40.4576, -41.0485, -40.6377,
                        -38.2275, -42.7481, -34.6084, -34.7048, -29.5149, -26.3935, -26.8952,
                        -34.1336, -26.2904, -28.2571, -32.5642, -36.7240, -35.5334, -38.2451,
                        -34.8177, -28.9754, -25.1096, -27.9768, -32.3184, -37.0269, -40.5136,
                        -40.8061, -36.4948, -40.3767, -38.9671, -38.3552, -34.1250, -30.9035,
                        -31.6112
                    ],
                    [
                        -100., -100., -100., -100., -100., -100., -100., -100., -100., -100.,
                        -100., -100., -100., -100., -100., -100., -100., -100., -100., -100.,
                        -100., -100., -100., -100., -100., -100., -100., -100., -100., -100.,
                        -100., -100., -100., -100., -100., -100., -100., -100., -100., -100.,
                        -100., -100., -100., -100., -100., -100., -100., -100., -100., -100.,
                        -100., -100., -100., -100., -100., -100., -100., -100., -100., -100.,
                        -100., -100., -100., -100.
                    ]
                ],
                [
                    # None, same as "repeatpad"
                    [
                        -25.7496, -24.9339, -24.1357, -23.1271, -23.7853, -26.1264, -29.1456,
                        -33.2060, -37.8179, -42.4833, -41.9386, -41.2164, -42.3566, -44.2575,
                        -40.0217, -36.6794, -36.6974, -38.7819, -42.0880, -45.5560, -39.9368,
                        -36.3219, -35.5981, -36.6434, -35.1851, -33.0684, -30.0437, -30.2010,
                        -34.3476, -42.1373, -38.8039, -37.3355, -40.4576, -41.0485, -40.6377,
                        -38.2275, -42.7481, -34.6084, -34.7048, -29.5149, -26.3935, -26.8952,
                        -34.1336, -26.2904, -28.2571, -32.5642, -36.7240, -35.5334, -38.2451,
                        -34.8177, -28.9754, -25.1096, -27.9768, -32.3184, -37.0269, -40.5136,
                        -40.8061, -36.4948, -40.3767, -38.9671, -38.3552, -34.1250, -30.9035,
                        -31.6112
                    ],
                    [
                        -100., -100., -100., -100., -100., -100., -100., -100., -100., -100.,
                        -100., -100., -100., -100., -100., -100., -100., -100., -100., -100.,
                        -100., -100., -100., -100., -100., -100., -100., -100., -100., -100.,
                        -100., -100., -100., -100., -100., -100., -100., -100., -100., -100.,
                        -100., -100., -100., -100., -100., -100., -100., -100., -100., -100.,
                        -100., -100., -100., -100., -100., -100., -100., -100., -100., -100.,
                        -100., -100., -100., -100.
                    ]
                ],
                [
                    # "pad"
                    [
                        -58.5260, -58.1155, -57.8623, -57.5059, -57.9178, -58.7171, -59.2343,
                        -59.9833, -60.9764, -62.0722, -63.5723, -65.7111, -67.5153, -68.7088,
                        -69.8325, -70.2987, -70.1548, -70.6233, -71.5702, -72.5159, -72.3821,
                        -70.1817, -67.0315, -64.1387, -62.2202, -61.0717, -60.4951, -61.6005,
                        -63.7358, -67.1400, -67.6185, -65.5635, -64.3593, -63.7138, -63.6209,
                        -66.4950, -72.6284, -63.3961, -56.8334, -52.7319, -50.6310, -51.3728,
                        -53.5619, -51.9190, -50.9708, -52.8684, -55.8073, -58.8227, -60.6991,
                        -57.0547, -52.7611, -51.4388, -54.4892, -60.8950, -66.1024, -72.4352,
                        -67.8538, -65.1463, -68.7588, -72.3080, -68.4864, -60.4688, -57.1516,
                        -60.9460
                    ],
                    [
                        -100., -100., -100., -100., -100., -100., -100., -100., -100., -100.,
                        -100., -100., -100., -100., -100., -100., -100., -100., -100., -100.,
                        -100., -100., -100., -100., -100., -100., -100., -100., -100., -100.,
                        -100., -100., -100., -100., -100., -100., -100., -100., -100., -100.,
                        -100., -100., -100., -100., -100., -100., -100., -100., -100., -100.,
                        -100., -100., -100., -100., -100., -100., -100., -100., -100., -100.,
                        -100., -100., -100., -100.
                    ]
                ]
            ]
        )
        # fmt: on
        MEL_BIN = [[976, 977], [976, 977], [976, 977], [196, 197]]
        input_speech = self._load_datasamples(1)
        feature_extractor = ClapFeatureExtractor()
        for padding, EXPECTED_VALUES, idx_in_mel in zip(
            ["repeat", "repeatpad", None, "pad"], EXPECTED_INPUT_FEATURES, MEL_BIN
        ):
            input_features = feature_extractor(input_speech, return_tensors="pt", padding=padding).input_features
            self.assertEqual(input_features.shape, (1, 4, 1001, 64))

            torch.testing.assert_close(input_features[0, 0, idx_in_mel[0]], EXPECTED_VALUES[0], rtol=1e-4, atol=1e-4)
            torch.testing.assert_close(input_features[0, 0, idx_in_mel[1]], EXPECTED_VALUES[1], rtol=1e-4, atol=1e-4)

        self.assertTrue(torch.all(input_features[0, 0] == input_features[0, 1]))
        self.assertTrue(torch.all(input_features[0, 0] == input_features[0, 2]))
        self.assertTrue(torch.all(input_features[0, 0] == input_features[0, 3]))

    def test_integration_rand_trunc_short_input(self):
        # fmt: off
        EXPECTED_INPUT_FEATURES = torch.tensor(
            [
                [
                    # "repeat"
                    [
                        -35.0483, -35.7865, -38.2884, -40.0220, -42.5349, -44.9489, -43.2228,
                        -44.6499, -47.6253, -49.6983, -50.2127, -52.5483, -52.2223, -51.9157,
                        -49.4082, -51.2024, -57.0476, -56.2803, -58.1618, -60.7474, -55.0389,
                        -60.9514, -59.3080, -50.4419, -47.8172, -48.7570, -55.2552, -44.5036,
                        -44.1148, -50.8218, -51.0968, -52.9408, -51.1037, -48.9789, -47.5897,
                        -52.0915, -55.4216, -54.1529, -58.0149, -58.0866, -52.7798, -52.6154,
                        -45.9144, -46.2008, -40.7603, -41.1703, -50.2250, -55.4112, -59.4818,
                        -54.5795, -53.5552, -51.3668, -49.8358, -50.3186, -54.0452, -57.6030,
                        -61.1589, -61.6415, -63.2756, -66.5890, -62.8543, -58.0665, -56.7203,
                        -56.7632
                    ],
                    [
                        -47.1320, -37.9961, -34.0076, -36.7109, -47.9057, -48.4924, -43.8371,
                        -44.9728, -48.1689, -52.9141, -57.6077, -52.8520, -44.8502, -45.6764,
                        -51.8389, -56.4284, -54.6972, -53.4889, -55.6077, -58.7149, -60.3760,
                        -54.0136, -56.0730, -55.9870, -54.4017, -53.1094, -53.5640, -50.3064,
                        -49.9520, -49.3239, -48.1668, -53.4852, -50.4561, -50.8688, -55.1970,
                        -51.5538, -53.0260, -59.6933, -54.8183, -59.5895, -55.9589, -50.3761,
                        -44.1282, -44.1463, -43.8540, -39.1168, -45.3893, -49.5542, -53.1505,
                        -55.2870, -50.3921, -46.8511, -47.4444, -49.5633, -56.0034, -59.0815,
                        -59.0018, -63.7589, -69.5745, -71.5789, -64.0498, -56.0558, -54.3475,
                        -54.7004
                    ]
                ],
                [
                    # "repeatpad"
                    [
                        -40.3184, -39.7186, -39.8807, -41.6508, -45.3613, -50.4785, -57.0297,
                        -60.4944, -59.1642, -58.9495, -60.4661, -62.5300, -58.4759, -55.2865,
                        -54.8973, -56.0780, -57.5482, -59.6557, -64.3309, -65.0330, -59.4941,
                        -56.8552, -55.0519, -55.9817, -56.9739, -55.2827, -54.5312, -51.4141,
                        -50.4289, -51.9131, -57.5821, -63.9979, -59.9180, -58.9489, -62.3247,
                        -62.6975, -63.7948, -60.5250, -64.6107, -58.7905, -57.0229, -54.3084,
                        -49.8445, -50.4459, -57.0172, -50.6425, -52.5992, -57.4207, -61.6358,
                        -60.6540, -63.1968, -57.4360, -52.3263, -51.7695, -57.1946, -62.9610,
                        -66.7359, -67.0335, -63.7440, -68.1775, -66.3798, -62.8650, -59.8972,
                        -59.3139
                    ],
                    [
                        -100., -100., -100., -100., -100., -100., -100., -100., -100., -100.,
                        -100., -100., -100., -100., -100., -100., -100., -100., -100., -100.,
                        -100., -100., -100., -100., -100., -100., -100., -100., -100., -100.,
                        -100., -100., -100., -100., -100., -100., -100., -100., -100., -100.,
                        -100., -100., -100., -100., -100., -100., -100., -100., -100., -100.,
                        -100., -100., -100., -100., -100., -100., -100., -100., -100., -100.,
                        -100., -100., -100., -100.
                    ]
                ],
                [
                    # None, same as "repeatpad"
                    [
                        -40.3184, -39.7186, -39.8807, -41.6508, -45.3613, -50.4785, -57.0297,
                        -60.4944, -59.1642, -58.9495, -60.4661, -62.5300, -58.4759, -55.2865,
                        -54.8973, -56.0780, -57.5482, -59.6557, -64.3309, -65.0330, -59.4941,
                        -56.8552, -55.0519, -55.9817, -56.9739, -55.2827, -54.5312, -51.4141,
                        -50.4289, -51.9131, -57.5821, -63.9979, -59.9180, -58.9489, -62.3247,
                        -62.6975, -63.7948, -60.5250, -64.6107, -58.7905, -57.0229, -54.3084,
                        -49.8445, -50.4459, -57.0172, -50.6425, -52.5992, -57.4207, -61.6358,
                        -60.6540, -63.1968, -57.4360, -52.3263, -51.7695, -57.1946, -62.9610,
                        -66.7359, -67.0335, -63.7440, -68.1775, -66.3798, -62.8650, -59.8972,
                        -59.3139
                    ],
                    [
                        -100., -100., -100., -100., -100., -100., -100., -100., -100., -100.,
                        -100., -100., -100., -100., -100., -100., -100., -100., -100., -100.,
                        -100., -100., -100., -100., -100., -100., -100., -100., -100., -100.,
                        -100., -100., -100., -100., -100., -100., -100., -100., -100., -100.,
                        -100., -100., -100., -100., -100., -100., -100., -100., -100., -100.,
                        -100., -100., -100., -100., -100., -100., -100., -100., -100., -100.,
                        -100., -100., -100., -100.
                    ]
                ],
                [
                    # "pad"
                    [
                        -73.3190, -73.6349, -74.1451, -74.8539, -75.7476, -76.5438, -78.5540,
                        -80.1339, -81.8911, -83.7560, -85.5387, -86.7466, -88.2072, -88.6090,
                        -88.8243, -89.0784, -89.4364, -89.8179, -91.3146, -92.2833, -91.7221,
                        -90.9440, -88.1315, -86.2425, -84.2281, -82.4893, -81.5993, -81.1328,
                        -81.5759, -83.1068, -85.6525, -88.9520, -88.9187, -87.2703, -86.3052,
                        -85.7188, -85.8802, -87.9996, -95.0464, -88.0133, -80.8561, -76.5597,
                        -74.2816, -74.8109, -77.3615, -76.0719, -75.3426, -77.6428, -80.9663,
                        -84.5275, -84.9907, -80.5205, -77.2851, -78.6259, -84.7740, -91.4535,
                        -98.1894, -94.3872, -92.3735, -97.6807, -98.1501, -91.4344, -85.2842,
                        -88.4338
                    ],
                    [
                        -100., -100., -100., -100., -100., -100., -100., -100., -100., -100.,
                        -100., -100., -100., -100., -100., -100., -100., -100., -100., -100.,
                        -100., -100., -100., -100., -100., -100., -100., -100., -100., -100.,
                        -100., -100., -100., -100., -100., -100., -100., -100., -100., -100.,
                        -100., -100., -100., -100., -100., -100., -100., -100., -100., -100.,
                        -100., -100., -100., -100., -100., -100., -100., -100., -100., -100.,
                        -100., -100., -100., -100.
                    ]
                ]
            ]
        )
        # fmt: on
        MEL_BIN = [[976, 977], [976, 977], [976, 977], [196, 197]]
        input_speech = self._load_datasamples(1)
        feature_extractor = ClapFeatureExtractor()
|
| 403 |
+
for padding, EXPECTED_VALUES, idx_in_mel in zip(
|
| 404 |
+
["repeat", "repeatpad", None, "pad"], EXPECTED_INPUT_FEATURES, MEL_BIN
|
| 405 |
+
):
|
| 406 |
+
input_features = feature_extractor(
|
| 407 |
+
input_speech, return_tensors="pt", truncation="rand_trunc", padding=padding
|
| 408 |
+
).input_features
|
| 409 |
+
self.assertEqual(input_features.shape, (1, 1, 1001, 64))
|
| 410 |
+
torch.testing.assert_close(input_features[0, 0, idx_in_mel[0]], EXPECTED_VALUES[0], rtol=1e-4, atol=1e-4)
|
| 411 |
+
torch.testing.assert_close(input_features[0, 0, idx_in_mel[1]], EXPECTED_VALUES[1], rtol=1e-4, atol=1e-4)
|
| 412 |
+
|
| 413 |
+
def test_integration_fusion_long_input(self):
|
| 414 |
+
# fmt: off
|
| 415 |
+
EXPECTED_INPUT_FEATURES = torch.tensor(
|
| 416 |
+
[
|
| 417 |
+
[
|
| 418 |
+
-11.1830, -10.1894, -8.6051, -4.8578, -1.3268, -8.4606, -14.5453,
|
| 419 |
+
-9.2017, 0.5781, 16.2129, 14.8289, 3.6326, -3.8794, -6.5544,
|
| 420 |
+
-2.4408, 1.9531, 6.0967, 1.7590, -7.6730, -6.1571, 2.0052,
|
| 421 |
+
16.6694, 20.6447, 21.2145, 13.4972, 15.9043, 16.8987, 4.1766,
|
| 422 |
+
11.9428, 21.2372, 12.3016, 4.8604, 6.7241, 1.8543, 4.9235,
|
| 423 |
+
5.3188, -0.9897, -1.2416, -6.5864, 2.9529, 2.9274, 6.4753,
|
| 424 |
+
10.2300, 11.2127, 3.4042, -1.0055, -6.0475, -6.7524, -3.9801,
|
| 425 |
+
-1.4434, 0.4740, -0.1584, -4.5457, -8.5746, -8.8428, -13.1475,
|
| 426 |
+
-9.6079, -8.5798, -4.1143, -3.7966, -7.1651, -6.1517, -8.0258,
|
| 427 |
+
-12.1486
|
| 428 |
+
],
|
| 429 |
+
[
|
| 430 |
+
-10.2017, -7.9924, -5.9517, -3.9372, -1.9735, -4.3130, 16.1647,
|
| 431 |
+
25.0592, 23.5532, 14.4974, -7.0778, -10.2262, 6.4782, 20.3454,
|
| 432 |
+
19.4269, 1.7976, -16.5070, 4.9380, 12.3390, 6.9285, -13.6325,
|
| 433 |
+
-8.5298, 1.0839, -5.9629, -8.4812, 3.1331, -2.0963, -16.6046,
|
| 434 |
+
-14.0070, -17.5707, -13.2080, -17.2168, -17.7770, -12.1111, -18.6184,
|
| 435 |
+
-17.1897, -13.9801, -12.0426, -23.5400, -25.6823, -23.5813, -18.7847,
|
| 436 |
+
-20.5473, -25.6458, -19.7585, -27.6007, -28.9276, -24.8948, -25.4458,
|
| 437 |
+
-22.2807, -19.6613, -19.2669, -15.7813, -19.6821, -24.3439, -22.2598,
|
| 438 |
+
-28.2631, -30.1017, -32.7646, -33.6525, -27.5639, -22.0548, -27.8054,
|
| 439 |
+
-29.6947
|
| 440 |
+
],
|
| 441 |
+
[
|
| 442 |
+
-9.2078, -7.2963, -6.2095, -7.9959, -2.9280, -11.1843, -6.1490,
|
| 443 |
+
5.0733, 19.2957, 21.4578, 14.6803, -3.3153, -6.3334, -2.3542,
|
| 444 |
+
6.9509, 15.2965, 14.6620, 5.2075, -0.0873, 1.1919, 18.1986,
|
| 445 |
+
20.8470, 10.8035, 2.2516, 7.6905, 7.7427, -1.2543, -5.0018,
|
| 446 |
+
0.9809, -2.1584, -5.4580, -5.4760, -11.8888, -9.0605, -8.4638,
|
| 447 |
+
-9.9897, -0.0540, -5.1629, 0.0483, -4.1504, -4.8140, -7.8236,
|
| 448 |
+
-9.0622, -10.1742, -8.9597, -11.5380, -16.5603, -17.1858, -17.5032,
|
| 449 |
+
-20.9326, -23.9543, -25.2602, -25.3429, -27.4536, -26.8859, -22.7852,
|
| 450 |
+
-25.8288, -24.8399, -23.8893, -24.2096, -26.5415, -23.7281, -25.6851,
|
| 451 |
+
-22.3629
|
| 452 |
+
],
|
| 453 |
+
[
|
| 454 |
+
1.3448, 2.9883, 4.0366, -0.8019, -10.4191, -10.0883, -4.3812,
|
| 455 |
+
0.8136, 2.1579, 0.0832, 1.0949, -0.9759, -5.5319, -4.6009,
|
| 456 |
+
-6.5452, -14.9155, -20.1584, -9.3611, -2.4271, 1.4031, 4.9910,
|
| 457 |
+
8.6916, 8.6785, 10.1973, 9.9029, 5.3840, 7.5336, 5.2803,
|
| 458 |
+
2.8144, -0.3138, 2.2216, 5.7328, 7.5574, 7.7402, 1.0681,
|
| 459 |
+
3.1049, 7.0742, 6.5588, 7.3712, 5.7881, 8.6874, 8.7725,
|
| 460 |
+
2.8133, -4.5809, -6.1317, -5.1719, -5.0192, -9.0977, -10.9391,
|
| 461 |
+
-6.0769, 1.6016, -0.8965, -7.2252, -7.8632, -11.4468, -11.7446,
|
| 462 |
+
-10.7447, -7.0601, -2.7748, -4.1798, -2.8433, -3.1352, 0.8097,
|
| 463 |
+
6.4212
|
| 464 |
+
]
|
| 465 |
+
]
|
| 466 |
+
)
|
| 467 |
+
# fmt: on
|
| 468 |
+
MEL_BIN = 963
|
| 469 |
+
input_speech = torch.cat([torch.tensor(x) for x in self._load_datasamples(5)])
|
| 470 |
+
feature_extractor = ClapFeatureExtractor()
|
| 471 |
+
for padding, EXPECTED_VALUES, block_idx in zip(
|
| 472 |
+
["repeat", "repeatpad", None, "pad"], EXPECTED_INPUT_FEATURES, [1, 2, 0, 3]
|
| 473 |
+
):
|
| 474 |
+
set_seed(987654321)
|
| 475 |
+
input_features = feature_extractor(input_speech, return_tensors="pt", padding=padding).input_features
|
| 476 |
+
self.assertEqual(input_features.shape, (1, 4, 1001, 64))
|
| 477 |
+
torch.testing.assert_close(input_features[0, block_idx, MEL_BIN], EXPECTED_VALUES, rtol=1e-3, atol=1e-3)
|
| 478 |
+
|
| 479 |
+
def test_integration_rand_trunc_long_input(self):
|
| 480 |
+
# fmt: off
|
| 481 |
+
EXPECTED_INPUT_FEATURES = torch.tensor(
|
| 482 |
+
[
|
| 483 |
+
[
|
| 484 |
+
-35.4022, -32.7555, -31.2004, -32.7764, -42.5770, -41.6339, -43.1630,
|
| 485 |
+
-44.5080, -44.3029, -48.9628, -39.5022, -39.2105, -43.1350, -43.2195,
|
| 486 |
+
-48.4894, -52.2344, -57.6891, -52.2228, -45.5155, -44.2893, -43.4697,
|
| 487 |
+
-46.6702, -43.7490, -40.4819, -42.7275, -46.3434, -46.8412, -41.2003,
|
| 488 |
+
-43.1681, -46.2948, -46.1925, -47.8333, -45.6812, -44.9182, -41.7786,
|
| 489 |
+
-43.3809, -44.3199, -42.8814, -45.4771, -46.7114, -46.9746, -42.7090,
|
| 490 |
+
-41.6057, -38.3965, -40.1980, -41.0263, -34.1256, -28.3289, -29.0201,
|
| 491 |
+
-30.4453, -29.5561, -30.1734, -25.9406, -19.0897, -15.8452, -20.1351,
|
| 492 |
+
-23.6515, -23.1194, -17.1845, -19.4399, -23.6527, -22.8768, -20.7279,
|
| 493 |
+
-22.7864
|
| 494 |
+
],
|
| 495 |
+
[
|
| 496 |
+
-35.7719, -27.2566, -23.6964, -27.5521, 0.2510, 7.4391, 1.3917,
|
| 497 |
+
-13.3417, -28.1758, -17.0856, -5.7723, -0.8000, -7.8832, -15.5548,
|
| 498 |
+
-30.5935, -24.7571, -13.7009, -10.3432, -21.2464, -24.8118, -19.4080,
|
| 499 |
+
-14.9779, -11.7991, -18.4485, -20.1982, -17.3652, -20.6328, -28.2967,
|
| 500 |
+
-25.7819, -21.8962, -28.5083, -29.5719, -30.2120, -35.7033, -31.8218,
|
| 501 |
+
-34.0408, -37.7744, -33.9653, -31.3009, -30.9063, -28.6153, -32.2202,
|
| 502 |
+
-28.5456, -28.8579, -32.5170, -37.9152, -43.0052, -46.4849, -44.0786,
|
| 503 |
+
-39.1933, -33.2757, -31.6313, -42.6386, -52.3679, -53.5785, -55.6444,
|
| 504 |
+
-47.0050, -47.6459, -56.6361, -60.6781, -61.5244, -55.8272, -60.4832,
|
| 505 |
+
-58.1897
|
| 506 |
+
],
|
| 507 |
+
[
|
| 508 |
+
-38.2686, -36.6285, -32.5835, -35.1693, -37.7938, -37.4035, -35.3132,
|
| 509 |
+
-35.6083, -36.3609, -40.9472, -36.7846, -36.1544, -38.9076, -39.3618,
|
| 510 |
+
-35.4953, -34.2809, -39.9466, -39.7433, -34.8347, -37.5674, -41.5689,
|
| 511 |
+
-38.9161, -34.3947, -30.2924, -30.4841, -34.5831, -28.9261, -24.8849,
|
| 512 |
+
-31.2324, -27.1622, -27.2107, -25.9385, -30.1691, -30.9223, -23.9495,
|
| 513 |
+
-25.6047, -26.7119, -28.5523, -27.7481, -32.8427, -35.4650, -31.0399,
|
| 514 |
+
-31.2073, -30.5163, -22.9819, -20.8892, -19.2510, -24.7905, -28.9426,
|
| 515 |
+
-28.1998, -26.7386, -25.0140, -27.9223, -32.9913, -33.1864, -34.9742,
|
| 516 |
+
-38.5995, -39.6990, -29.3203, -22.4697, -25.6415, -33.5608, -33.0945,
|
| 517 |
+
-27.1716
|
| 518 |
+
],
|
| 519 |
+
[
|
| 520 |
+
-33.2015, -28.7741, -21.9457, -23.4888, -32.1072, -8.6307, 3.2724,
|
| 521 |
+
5.9157, -0.9221, -30.1814, -31.0015, -27.4508, -27.0477, -9.5342,
|
| 522 |
+
0.3221, 0.6511, -7.1596, -25.9707, -32.8924, -32.2300, -13.8974,
|
| 523 |
+
-0.4895, 0.9168, -10.7663, -27.1176, -35.0829, -11.6859, -4.8855,
|
| 524 |
+
-11.8898, -26.6167, -5.6192, -3.8443, -19.7947, -14.4101, -8.6236,
|
| 525 |
+
-21.2458, -21.0801, -17.9136, -24.4663, -18.6333, -24.8085, -15.5854,
|
| 526 |
+
-15.4344, -11.5046, -22.3625, -27.3387, -32.4353, -30.9670, -31.3789,
|
| 527 |
+
-35.4044, -34.4591, -25.2433, -28.0773, -33.8736, -33.0224, -33.3155,
|
| 528 |
+
-38.5302, -39.2741, -36.6395, -34.7729, -32.4483, -42.4001, -49.2857,
|
| 529 |
+
-39.1682
|
| 530 |
+
]
|
| 531 |
+
]
|
| 532 |
+
)
|
| 533 |
+
# fmt: on
|
| 534 |
+
MEL_BIN = 963
|
| 535 |
+
SEEDS = [987654321, 1234, 666, 5555]
|
| 536 |
+
input_speech = torch.cat([torch.tensor(x) for x in self._load_datasamples(5)])
|
| 537 |
+
feature_extractor = ClapFeatureExtractor()
|
| 538 |
+
for padding, EXPECTED_VALUES, seed in zip(
|
| 539 |
+
["repeat", "repeatpad", None, "pad"], EXPECTED_INPUT_FEATURES, SEEDS
|
| 540 |
+
):
|
| 541 |
+
set_seed(seed)
|
| 542 |
+
input_features = feature_extractor(
|
| 543 |
+
input_speech, return_tensors="pt", truncation="rand_trunc", padding=padding
|
| 544 |
+
).input_features
|
| 545 |
+
self.assertEqual(input_features.shape, (1, 1, 1001, 64))
|
| 546 |
+
torch.testing.assert_close(input_features[0, 0, MEL_BIN], EXPECTED_VALUES, rtol=1e-4, atol=1e-4)
|
docs/transformers/tests/models/clap/test_modeling_clap.py
ADDED
@@ -0,0 +1,755 @@
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch CLAP model."""

import inspect
import os
import tempfile
import unittest

import numpy as np
from datasets import load_dataset

from transformers import ClapAudioConfig, ClapConfig, ClapProcessor, ClapTextConfig
from transformers.testing_utils import require_torch, slow, torch_device
from transformers.utils import is_torch_available

from ...test_configuration_common import ConfigTester
from ...test_modeling_common import (
    ModelTesterMixin,
    _config_zero_init,
    floats_tensor,
    ids_tensor,
    random_attention_mask,
)
from ...test_pipeline_mixin import PipelineTesterMixin


if is_torch_available():
    import torch
    from torch import nn

    from transformers import (
        ClapAudioModel,
        ClapAudioModelWithProjection,
        ClapModel,
        ClapTextModel,
        ClapTextModelWithProjection,
    )


class ClapAudioModelTester:
    def __init__(
        self,
        parent,
        batch_size=12,
        image_size=60,
        num_mel_bins=16,
        window_size=4,
        spec_size=64,
        patch_size=2,
        patch_stride=2,
        seq_length=16,
        freq_ratio=2,
        num_channels=3,
        is_training=True,
        hidden_size=32,
        patch_embeds_hidden_size=16,
        projection_dim=32,
        depths=[2, 2],
        num_hidden_layers=2,
        num_heads=[2, 2],
        intermediate_size=37,
        dropout=0.1,
        attention_dropout=0.1,
        initializer_range=0.02,
        scope=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.image_size = image_size
        self.num_mel_bins = num_mel_bins
        self.window_size = window_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.is_training = is_training
        self.hidden_size = hidden_size
        self.projection_dim = projection_dim
        self.num_hidden_layers = num_hidden_layers
        self.depths = depths
        self.num_heads = num_heads
        self.num_attention_heads = num_heads[0]
        self.seq_length = seq_length
        self.spec_size = spec_size
        self.freq_ratio = freq_ratio
        self.patch_stride = patch_stride
        self.patch_embeds_hidden_size = patch_embeds_hidden_size
        self.intermediate_size = intermediate_size
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.initializer_range = initializer_range
        self.scope = scope

    def prepare_config_and_inputs(self):
        input_features = floats_tensor([self.batch_size, 1, self.hidden_size, self.num_mel_bins])
        config = self.get_config()

        return config, input_features

    def get_config(self):
        return ClapAudioConfig(
            image_size=self.image_size,
            patch_size=self.patch_size,
            num_mel_bins=self.num_mel_bins,
            window_size=self.window_size,
            num_channels=self.num_channels,
            hidden_size=self.hidden_size,
            patch_stride=self.patch_stride,
            projection_dim=self.projection_dim,
            depths=self.depths,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_heads,
            intermediate_size=self.intermediate_size,
            dropout=self.dropout,
            attention_dropout=self.attention_dropout,
            initializer_range=self.initializer_range,
            spec_size=self.spec_size,
            freq_ratio=self.freq_ratio,
            patch_embeds_hidden_size=self.patch_embeds_hidden_size,
        )

    def create_and_check_model(self, config, input_features):
        model = ClapAudioModel(config=config)
        model.to(torch_device)
        model.eval()
        with torch.no_grad():
            result = model(input_features)
        self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))

    def create_and_check_model_with_projection(self, config, input_features):
        model = ClapAudioModelWithProjection(config=config)
        model.to(torch_device)
        model.eval()
        with torch.no_grad():
            result = model(input_features)
        self.parent.assertEqual(result.audio_embeds.shape, (self.batch_size, self.projection_dim))

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, input_features = config_and_inputs
        inputs_dict = {"input_features": input_features}
        return config, inputs_dict


@require_torch
class ClapAudioModelTest(ModelTesterMixin, unittest.TestCase):
    """
    Here we also overwrite some of the tests of test_modeling_common.py, as CLAP does not use input_ids, inputs_embeds,
    attention_mask and seq_length.
    """

    all_model_classes = (ClapAudioModel, ClapAudioModelWithProjection) if is_torch_available() else ()
    fx_compatible = False
    test_pruning = False
    test_resize_embeddings = False
    test_head_masking = False

    def setUp(self):
        self.model_tester = ClapAudioModelTester(self)
        self.config_tester = ConfigTester(self, config_class=ClapAudioConfig, has_text_modality=False, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    @unittest.skip(reason="ClapAudioModel does not use inputs_embeds")
    def test_inputs_embeds(self):
        pass

    def test_model_get_set_embeddings(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            self.assertIsInstance(model.get_input_embeddings(), (nn.Module))
            x = model.get_output_embeddings()
            self.assertTrue(x is None or isinstance(x, nn.Linear))

    def test_hidden_states_output(self):
        def check_hidden_states_output(inputs_dict, config, model_class):
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            hidden_states = outputs.hidden_states

            expected_num_layers = getattr(
                self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
            )
            self.assertEqual(len(hidden_states), expected_num_layers)

            self.assertListEqual(
                list(hidden_states[0].shape[-2:]),
                [2 * self.model_tester.patch_embeds_hidden_size, 2 * self.model_tester.patch_embeds_hidden_size],
            )

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            inputs_dict["output_hidden_states"] = True
            check_hidden_states_output(inputs_dict, config, model_class)

            # check that output_hidden_states also work using config
            del inputs_dict["output_hidden_states"]
            config.output_hidden_states = True

            check_hidden_states_output(inputs_dict, config, model_class)

    @unittest.skip(reason="ClapAudioModel does not output any loss term in the forward pass")
    def test_retain_grad_hidden_states_attentions(self):
        pass

    def test_forward_signature(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.forward)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            expected_arg_names = ["input_features"]
            self.assertListEqual(arg_names[:1], expected_arg_names)

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_model_with_projection(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model_with_projection(*config_and_inputs)

    @unittest.skip(reason="ClapAudioModel does not output any loss term in the forward pass")
    def test_training(self):
        pass

    @unittest.skip(reason="ClapAudioModel does not output any loss term in the forward pass")
    def test_training_gradient_checkpointing(self):
        pass

    @unittest.skip(
        reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant(self):
        pass

    @unittest.skip(
        reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass

    @slow
    def test_model_from_pretrained(self):
        model_name = "laion/clap-htsat-fused"
        model = ClapAudioModel.from_pretrained(model_name)
        self.assertIsNotNone(model)

    @slow
    def test_model_with_projection_from_pretrained(self):
        model_name = "laion/clap-htsat-fused"
        model = ClapAudioModelWithProjection.from_pretrained(model_name)
        self.assertIsNotNone(model)
        self.assertTrue(hasattr(model, "audio_projection"))


class ClapTextModelTester:
    def __init__(
        self,
        parent,
        batch_size=12,
        seq_length=7,
        is_training=True,
        use_input_mask=True,
        use_labels=True,
        vocab_size=99,
        hidden_size=32,
        projection_dim=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=37,
        dropout=0.1,
        attention_dropout=0.1,
        max_position_embeddings=512,
        initializer_range=0.02,
        scope=None,
        projection_hidden_act="relu",
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_input_mask = use_input_mask
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.projection_dim = projection_dim
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.scope = scope
        self.projection_hidden_act = projection_hidden_act

    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
            input_mask = random_attention_mask([self.batch_size, self.seq_length])

        if input_mask is not None:
            batch_size, seq_length = input_mask.shape
            rnd_start_indices = np.random.randint(1, seq_length - 1, size=(batch_size,))
            for batch_idx, start_index in enumerate(rnd_start_indices):
                input_mask[batch_idx, :start_index] = 1
                input_mask[batch_idx, start_index:] = 0

        config = self.get_config()

        return config, input_ids, input_mask

    def get_config(self):
        return ClapTextConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            projection_dim=self.projection_dim,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            dropout=self.dropout,
            attention_dropout=self.attention_dropout,
            max_position_embeddings=self.max_position_embeddings,
            initializer_range=self.initializer_range,
            projection_hidden_act=self.projection_hidden_act,
        )

    def create_and_check_model(self, config, input_ids, input_mask):
        model = ClapTextModel(config=config)
        model.to(torch_device)
        model.eval()
        with torch.no_grad():
            result = model(input_ids, attention_mask=input_mask)
            result = model(input_ids)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
        self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))

    def create_and_check_model_with_projection(self, config, input_ids, input_mask):
        model = ClapTextModelWithProjection(config=config)
        model.to(torch_device)
        model.eval()
        with torch.no_grad():
            result = model(input_ids, attention_mask=input_mask)
            result = model(input_ids)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
        self.parent.assertEqual(result.text_embeds.shape, (self.batch_size, self.projection_dim))

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, input_ids, input_mask = config_and_inputs
        inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
        return config, inputs_dict


@require_torch
class ClapTextModelTest(ModelTesterMixin, unittest.TestCase):
    all_model_classes = (ClapTextModel, ClapTextModelWithProjection) if is_torch_available() else ()
    fx_compatible = False
    test_pruning = False
    test_head_masking = False

    def setUp(self):
        self.model_tester = ClapTextModelTester(self)
        self.config_tester = ConfigTester(self, config_class=ClapTextConfig, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_model_with_projection(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model_with_projection(*config_and_inputs)

    @unittest.skip(reason="ClapTextModel does not output any loss term in the forward pass")
    def test_training(self):
        pass

    @unittest.skip(reason="ClapTextModel does not output any loss term in the forward pass")
    def test_training_gradient_checkpointing(self):
        pass

    @unittest.skip(
        reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant(self):
        pass

    @unittest.skip(
        reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass

    @unittest.skip(reason="ClapTextModel does not use inputs_embeds")
    def test_inputs_embeds(self):
        pass

    @slow
    def test_model_from_pretrained(self):
        model_name = "laion/clap-htsat-fused"
        model = ClapTextModel.from_pretrained(model_name)
        self.assertIsNotNone(model)

    @slow
    def test_model_with_projection_from_pretrained(self):
        model_name = "laion/clap-htsat-fused"
        model = ClapTextModelWithProjection.from_pretrained(model_name)
        self.assertIsNotNone(model)
        self.assertTrue(hasattr(model, "text_projection"))


class ClapModelTester:
    def __init__(self, parent, text_kwargs=None, audio_kwargs=None, is_training=True):
        if text_kwargs is None:
            text_kwargs = {}
        if audio_kwargs is None:
            audio_kwargs = {}

        self.parent = parent
        self.text_model_tester = ClapTextModelTester(parent, **text_kwargs)
        self.audio_model_tester = ClapAudioModelTester(parent, **audio_kwargs)
        self.batch_size = self.text_model_tester.batch_size  # need bs for batching_equivalence test
        self.is_training = is_training

    def prepare_config_and_inputs(self):
        _, input_ids, attention_mask = self.text_model_tester.prepare_config_and_inputs()
        _, input_features = self.audio_model_tester.prepare_config_and_inputs()

        config = self.get_config()

        return config, input_ids, attention_mask, input_features

    def get_config(self):
        return ClapConfig.from_text_audio_configs(
            self.text_model_tester.get_config(), self.audio_model_tester.get_config(), projection_dim=64
        )

    def create_and_check_model(self, config, input_ids, attention_mask, input_features):
        model = ClapModel(config).to(torch_device).eval()
        with torch.no_grad():
            result = model(input_ids, input_features, attention_mask)
        self.parent.assertEqual(
            result.logits_per_audio.shape, (self.audio_model_tester.batch_size, self.text_model_tester.batch_size)
        )
        self.parent.assertEqual(
            result.logits_per_text.shape, (self.text_model_tester.batch_size, self.audio_model_tester.batch_size)
        )

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, input_ids, attention_mask, input_features = config_and_inputs
        inputs_dict = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "input_features": input_features,
            "return_loss": True,
        }
        return config, inputs_dict


@require_torch
class ClapModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    all_model_classes = (ClapModel,) if is_torch_available() else ()
    pipeline_model_mapping = {"feature-extraction": ClapModel} if is_torch_available() else {}
    fx_compatible = False
    test_head_masking = False
    test_pruning = False
    test_resize_embeddings = False
    test_attention_outputs = False

    def setUp(self):
        self.model_tester = ClapModelTester(self)
        common_properties = ["logit_scale_init_value", "projection_hidden_act", "projection_dim"]
        self.config_tester = ConfigTester(
            self, config_class=ClapConfig, has_text_modality=False, common_properties=common_properties
        )

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_config(self):
        self.config_tester.run_common_tests()

    @unittest.skip(reason="Hidden_states is tested in individual model tests")
    def test_hidden_states_output(self):
        pass

    @unittest.skip(reason="Inputs_embeds is tested in individual model tests")
    def test_inputs_embeds(self):
        pass

    @unittest.skip(reason="Retain_grad is tested in individual model tests")
    def test_retain_grad_hidden_states_attentions(self):
        pass

    @unittest.skip(reason="ClapModel does not have input/output embeddings")
    def test_model_get_set_embeddings(self):
        pass

    # override as the `logit_scale` parameter initialization is different for CLAP
    def test_initialization(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        configs_no_init = _config_zero_init(config)
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            for name, param in model.named_parameters():
                if param.requires_grad:
                    # check if `logit_scale` is initialized as per the original implementation
                    if name == "logit_scale":
                        self.assertAlmostEqual(
                            param.data.item(),
                            np.log(1 / 0.07),
                            delta=1e-3,
                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                        )
                    else:
                        self.assertIn(
                            ((param.data.mean() * 1e9).round() / 1e9).item(),
                            [0.0, 1.0],
                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                        )

    def _create_and_check_torchscript(self, config, inputs_dict):
        if not self.test_torchscript:
            self.skipTest(reason="test_torchscript is set to False")

        configs_no_init = _config_zero_init(config)  # To be sure we have no Nan
        configs_no_init.torchscript = True
        configs_no_init.return_dict = False
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            model.to(torch_device)
            model.eval()

            try:
                input_ids = inputs_dict["input_ids"]
                input_features = inputs_dict["input_features"]  # CLAP needs input_features
                traced_model = torch.jit.trace(model, (input_ids, input_features))
            except RuntimeError:
                self.fail("Couldn't trace module.")

            with tempfile.TemporaryDirectory() as tmp_dir_name:
                pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")

                try:
                    torch.jit.save(traced_model, pt_file_name)
                except Exception:
                    self.fail("Couldn't save module.")

                try:
                    loaded_model = torch.jit.load(pt_file_name)
                except Exception:
                    self.fail("Couldn't load module.")

            model.to(torch_device)
            model.eval()

            loaded_model.to(torch_device)
            loaded_model.eval()

            model_state_dict = model.state_dict()
            loaded_model_state_dict = loaded_model.state_dict()

            non_persistent_buffers = {}
            for key in loaded_model_state_dict.keys():
                if key not in model_state_dict.keys():
                    non_persistent_buffers[key] = loaded_model_state_dict[key]

            loaded_model_state_dict = {
                key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers
            }

            self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))

            model_buffers = list(model.buffers())
            for non_persistent_buffer in non_persistent_buffers.values():
                found_buffer = False
                for i, model_buffer in enumerate(model_buffers):
                    if torch.equal(non_persistent_buffer, model_buffer):
                        found_buffer = True
                        break

                self.assertTrue(found_buffer)
                model_buffers.pop(i)

            models_equal = True
            for layer_name, p1 in model_state_dict.items():
                p2 = loaded_model_state_dict[layer_name]
                if p1.data.ne(p2.data).sum() > 0:
                    models_equal = False

            self.assertTrue(models_equal)

    def test_load_audio_text_config(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        # Save ClapConfig and check if we can load ClapAudioConfig from it
        with tempfile.TemporaryDirectory() as tmp_dir_name:
            config.save_pretrained(tmp_dir_name)
            audio_config = ClapAudioConfig.from_pretrained(tmp_dir_name)
            self.assertDictEqual(config.audio_config.to_dict(), audio_config.to_dict())

        # Save ClapConfig and check if we can load ClapTextConfig from it
        with tempfile.TemporaryDirectory() as tmp_dir_name:
            config.save_pretrained(tmp_dir_name)
            text_config = ClapTextConfig.from_pretrained(tmp_dir_name)
            self.assertDictEqual(config.text_config.to_dict(), text_config.to_dict())

    @slow
    def test_model_from_pretrained(self):
        model_name = "laion/clap-htsat-fused"
        model = ClapModel.from_pretrained(model_name)
        self.assertIsNotNone(model)


@slow
@require_torch
class ClapModelIntegrationTest(unittest.TestCase):
    paddings = ["repeatpad", "repeat", "pad"]

    def test_integration_unfused(self):
        EXPECTED_MEANS_UNFUSED = {
            "repeatpad": 0.0024,
            "pad": 0.0020,
            "repeat": 0.0023,
        }

        librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        audio_sample = librispeech_dummy[-1]

        model_id = "laion/clap-htsat-unfused"

        model = ClapModel.from_pretrained(model_id).to(torch_device)
        processor = ClapProcessor.from_pretrained(model_id)

        for padding in self.paddings:
            inputs = processor(audios=audio_sample["audio"]["array"], return_tensors="pt", padding=padding).to(
                torch_device
            )

            audio_embed = model.get_audio_features(**inputs)
            expected_mean = EXPECTED_MEANS_UNFUSED[padding]

            self.assertTrue(
                torch.allclose(audio_embed.cpu().mean(), torch.tensor([expected_mean]), atol=1e-3, rtol=1e-3)
            )

    def test_integration_fused(self):
        EXPECTED_MEANS_FUSED = {
            "repeatpad": 0.00069,
            "repeat": 0.00196,
            "pad": -0.000379,
        }

        librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        audio_sample = librispeech_dummy[-1]

        model_id = "laion/clap-htsat-fused"

        model = ClapModel.from_pretrained(model_id).to(torch_device)
        processor = ClapProcessor.from_pretrained(model_id)

        for padding in self.paddings:
            inputs = processor(
                audios=audio_sample["audio"]["array"], return_tensors="pt", padding=padding, truncation="fusion"
            ).to(torch_device)

            audio_embed = model.get_audio_features(**inputs)
            expected_mean = EXPECTED_MEANS_FUSED[padding]

            self.assertTrue(
                torch.allclose(audio_embed.cpu().mean(), torch.tensor([expected_mean]), atol=1e-3, rtol=1e-3)
            )

    def test_batched_fused(self):
        EXPECTED_MEANS_FUSED = {
            "repeatpad": 0.0010,
            "repeat": 0.0020,
            "pad": 0.0006,
        }

        librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        audio_samples = [sample["array"] for sample in librispeech_dummy[0:4]["audio"]]

        model_id = "laion/clap-htsat-fused"

        model = ClapModel.from_pretrained(model_id).to(torch_device)
        processor = ClapProcessor.from_pretrained(model_id)

        for padding in self.paddings:
            inputs = processor(audios=audio_samples, return_tensors="pt", padding=padding, truncation="fusion").to(
                torch_device
            )

            audio_embed = model.get_audio_features(**inputs)
            expected_mean = EXPECTED_MEANS_FUSED[padding]

            self.assertTrue(
                torch.allclose(audio_embed.cpu().mean(), torch.tensor([expected_mean]), atol=1e-3, rtol=1e-3)
            )

    def test_batched_unfused(self):
        EXPECTED_MEANS_FUSED = {
            "repeatpad": 0.0016,
            "repeat": 0.0019,
            "pad": 0.0019,
        }

        librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        audio_samples = [sample["array"] for sample in librispeech_dummy[0:4]["audio"]]

        model_id = "laion/clap-htsat-unfused"

        model = ClapModel.from_pretrained(model_id).to(torch_device)
        processor = ClapProcessor.from_pretrained(model_id)

        for padding in self.paddings:
            inputs = processor(audios=audio_samples, return_tensors="pt", padding=padding).to(torch_device)

            audio_embed = model.get_audio_features(**inputs)
            expected_mean = EXPECTED_MEANS_FUSED[padding]

            self.assertTrue(
                torch.allclose(audio_embed.cpu().mean(), torch.tensor([expected_mean]), atol=1e-3, rtol=1e-3)
            )
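The integration tests above only check embedding means against fixed values; the same `get_audio_features` and `get_text_features` calls compose into audio-text retrieval. A minimal sketch (not part of the diffed test file), under the same assumptions as the tests (network access to the `laion/clap-htsat-unfused` checkpoint) and with a synthetic waveform standing in for the LibriSpeech sample:

# Sketch only, not part of the diff: audio-text similarity with the checkpoint used in the tests.
import numpy as np
import torch
from transformers import ClapModel, ClapProcessor

model_id = "laion/clap-htsat-unfused"
model = ClapModel.from_pretrained(model_id)
processor = ClapProcessor.from_pretrained(model_id)

waveform = np.zeros(48_000, dtype=np.float32)  # placeholder audio clip
audio_inputs = processor(audios=waveform, return_tensors="pt", padding="repeatpad")
text_inputs = processor(text=["a recording of speech", "a dog barking"], return_tensors="pt", padding=True)

with torch.no_grad():
    audio_embeds = model.get_audio_features(**audio_inputs)
    text_embeds = model.get_text_features(**text_inputs)

# Cosine similarity between the L2-normalized projections
audio_embeds = audio_embeds / audio_embeds.norm(dim=-1, keepdim=True)
text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
similarity = audio_embeds @ text_embeds.T  # shape (1, 2): one row of text scores per clip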
docs/transformers/tests/models/clap/test_processor_clap.py
ADDED
@@ -0,0 +1,125 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import shutil
import tempfile
import unittest

from transformers import ClapFeatureExtractor, ClapProcessor, RobertaTokenizer, RobertaTokenizerFast
from transformers.testing_utils import require_sentencepiece, require_torchaudio

from .test_feature_extraction_clap import floats_list


@require_torchaudio
@require_sentencepiece
class ClapProcessorTest(unittest.TestCase):
    def setUp(self):
        self.checkpoint = "laion/clap-htsat-unfused"
        self.tmpdirname = tempfile.mkdtemp()

    def get_tokenizer(self, **kwargs):
        return RobertaTokenizer.from_pretrained(self.checkpoint, **kwargs)

    def get_feature_extractor(self, **kwargs):
        return ClapFeatureExtractor.from_pretrained(self.checkpoint, **kwargs)

    def tearDown(self):
        shutil.rmtree(self.tmpdirname)

    def test_save_load_pretrained_default(self):
        tokenizer = self.get_tokenizer()
        feature_extractor = self.get_feature_extractor()

        processor = ClapProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)

        processor.save_pretrained(self.tmpdirname)
        processor = ClapProcessor.from_pretrained(self.tmpdirname)

        self.assertEqual(processor.tokenizer.get_vocab(), tokenizer.get_vocab())
        self.assertIsInstance(processor.tokenizer, RobertaTokenizerFast)

        self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor.to_json_string())
        self.assertIsInstance(processor.feature_extractor, ClapFeatureExtractor)

    def test_save_load_pretrained_additional_features(self):
        processor = ClapProcessor(tokenizer=self.get_tokenizer(), feature_extractor=self.get_feature_extractor())
        processor.save_pretrained(self.tmpdirname)

        tokenizer_add_kwargs = self.get_tokenizer(bos_token="(BOS)", eos_token="(EOS)")
        feature_extractor_add_kwargs = self.get_feature_extractor(do_normalize=False, padding_value=1.0)

        processor = ClapProcessor.from_pretrained(
            self.tmpdirname, bos_token="(BOS)", eos_token="(EOS)", do_normalize=False, padding_value=1.0
        )

        self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab())
        self.assertIsInstance(processor.tokenizer, RobertaTokenizerFast)

        self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string())
        self.assertIsInstance(processor.feature_extractor, ClapFeatureExtractor)

    def test_feature_extractor(self):
        feature_extractor = self.get_feature_extractor()
        tokenizer = self.get_tokenizer()

        processor = ClapProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)

        raw_speech = floats_list((3, 1000))

        input_feat_extract = feature_extractor(raw_speech, return_tensors="np")
        input_processor = processor(audios=raw_speech, return_tensors="np")

        for key in input_feat_extract.keys():
            self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2)

    def test_tokenizer(self):
        feature_extractor = self.get_feature_extractor()
        tokenizer = self.get_tokenizer()

        processor = ClapProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)

        input_str = "This is a test string"

        encoded_processor = processor(text=input_str)

        encoded_tok = tokenizer(input_str)

        for key in encoded_tok.keys():
            self.assertListEqual(encoded_tok[key], encoded_processor[key])

    def test_tokenizer_decode(self):
        feature_extractor = self.get_feature_extractor()
        tokenizer = self.get_tokenizer()

        processor = ClapProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)

        predicted_ids = [[1, 4, 5, 8, 1, 0, 8], [3, 4, 3, 1, 1, 8, 9]]

        decoded_processor = processor.batch_decode(predicted_ids)
        decoded_tok = tokenizer.batch_decode(predicted_ids)

        self.assertListEqual(decoded_tok, decoded_processor)

    def test_model_input_names(self):
        feature_extractor = self.get_feature_extractor()
        tokenizer = self.get_tokenizer()

        processor = ClapProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)

        self.assertListEqual(
            processor.model_input_names[2:],
            feature_extractor.model_input_names,
            msg="`processor` and `feature_extractor` model input names do not match",
        )
docs/transformers/tests/models/clip/__init__.py
ADDED
File without changes
docs/transformers/tests/models/clip/test_image_processing_clip.py
ADDED
@@ -0,0 +1,128 @@
# Copyright 2021 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import unittest

from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torchvision_available, is_vision_available

from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs


if is_vision_available():
    from transformers import CLIPImageProcessor

if is_torchvision_available():
    from transformers import CLIPImageProcessorFast


class CLIPImageProcessingTester:
    def __init__(
        self,
        parent,
        batch_size=7,
        num_channels=3,
        image_size=18,
        min_resolution=30,
        max_resolution=400,
        do_resize=True,
        size=None,
        do_center_crop=True,
        crop_size=None,
        do_normalize=True,
        image_mean=[0.48145466, 0.4578275, 0.40821073],
        image_std=[0.26862954, 0.26130258, 0.27577711],
        do_convert_rgb=True,
    ):
        super().__init__()
        size = size if size is not None else {"shortest_edge": 20}
        crop_size = crop_size if crop_size is not None else {"height": 18, "width": 18}
        self.parent = parent
        self.batch_size = batch_size
        self.num_channels = num_channels
        self.image_size = image_size
        self.min_resolution = min_resolution
        self.max_resolution = max_resolution
        self.do_resize = do_resize
        self.size = size
        self.do_center_crop = do_center_crop
        self.crop_size = crop_size
        self.do_normalize = do_normalize
        self.image_mean = image_mean
        self.image_std = image_std
        self.do_convert_rgb = do_convert_rgb

    def prepare_image_processor_dict(self):
        return {
            "do_resize": self.do_resize,
            "size": self.size,
            "do_center_crop": self.do_center_crop,
            "crop_size": self.crop_size,
            "do_normalize": self.do_normalize,
            "image_mean": self.image_mean,
            "image_std": self.image_std,
            "do_convert_rgb": self.do_convert_rgb,
        }

    def expected_output_image_shape(self, images):
        return self.num_channels, self.crop_size["height"], self.crop_size["width"]

    def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False):
        return prepare_image_inputs(
            batch_size=self.batch_size,
            num_channels=self.num_channels,
            min_resolution=self.min_resolution,
            max_resolution=self.max_resolution,
            equal_resolution=equal_resolution,
            numpify=numpify,
            torchify=torchify,
        )


@require_torch
@require_vision
class CLIPImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
    image_processing_class = CLIPImageProcessor if is_vision_available() else None
    fast_image_processing_class = CLIPImageProcessorFast if is_torchvision_available() else None

    def setUp(self):
        super().setUp()
        self.image_processor_tester = CLIPImageProcessingTester(self)

    @property
    def image_processor_dict(self):
        return self.image_processor_tester.prepare_image_processor_dict()

    def test_image_processor_properties(self):
        for image_processing_class in self.image_processor_list:
            image_processing = image_processing_class(**self.image_processor_dict)
            self.assertTrue(hasattr(image_processing, "do_resize"))
            self.assertTrue(hasattr(image_processing, "size"))
            self.assertTrue(hasattr(image_processing, "do_center_crop"))
            self.assertTrue(hasattr(image_processing, "center_crop"))
            self.assertTrue(hasattr(image_processing, "do_normalize"))
            self.assertTrue(hasattr(image_processing, "image_mean"))
            self.assertTrue(hasattr(image_processing, "image_std"))
            self.assertTrue(hasattr(image_processing, "do_convert_rgb"))

    def test_image_processor_from_dict_with_kwargs(self):
        for image_processing_class in self.image_processor_list:
            image_processor = image_processing_class.from_dict(self.image_processor_dict)
            self.assertEqual(image_processor.size, {"shortest_edge": 20})
            self.assertEqual(image_processor.crop_size, {"height": 18, "width": 18})

            image_processor = image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84)
            self.assertEqual(image_processor.size, {"shortest_edge": 42})
            self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84})
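The tester above fixes a concrete preprocessing configuration: resize to shortest_edge=20, center-crop to 18x18, then normalize with the CLIP mean/std. A minimal sketch of that same pipeline outside the test harness, assuming numpy, PIL, and transformers are installed; the random input image is just a placeholder:

import numpy as np
from PIL import Image
from transformers import CLIPImageProcessor

image_processor = CLIPImageProcessor(
    do_resize=True,
    size={"shortest_edge": 20},
    do_center_crop=True,
    crop_size={"height": 18, "width": 18},
    do_normalize=True,
    image_mean=[0.48145466, 0.4578275, 0.40821073],
    image_std=[0.26862954, 0.26130258, 0.27577711],
)

# A 30x40 RGB image: resize brings the shortest edge to 20, the crop trims to 18x18.
image = Image.fromarray(np.random.randint(0, 256, (30, 40, 3), dtype=np.uint8))
pixel_values = image_processor(image, return_tensors="np")["pixel_values"]
print(pixel_values.shape)  # (1, 3, 18, 18), matching expected_output_image_shape above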
docs/transformers/tests/models/clip/test_modeling_clip.py
ADDED
@@ -0,0 +1,948 @@
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch CLIP model."""

import inspect
import os
import tempfile
import unittest

import numpy as np
import requests
from parameterized import parameterized
from pytest import mark

from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
from transformers.testing_utils import (
    require_flash_attn,
    require_torch,
    require_torch_gpu,
    require_torch_sdpa,
    require_vision,
    slow,
    torch_device,
)
from transformers.utils import (
    is_torch_available,
    is_vision_available,
)

from ...test_configuration_common import ConfigTester
from ...test_modeling_common import (
    TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION,
    ModelTesterMixin,
    _config_zero_init,
    floats_tensor,
    ids_tensor,
    is_flaky,
    random_attention_mask,
)
from ...test_pipeline_mixin import PipelineTesterMixin


if is_torch_available():
    import torch
    from torch import nn

    from transformers import (
        CLIPForImageClassification,
        CLIPModel,
        CLIPTextModel,
        CLIPTextModelWithProjection,
        CLIPVisionModel,
        CLIPVisionModelWithProjection,
    )

if is_vision_available():
    from PIL import Image

    from transformers import CLIPProcessor


class CLIPVisionModelTester:
    def __init__(
        self,
        parent,
        batch_size=12,
        image_size=30,
        patch_size=2,
        num_channels=3,
        is_training=True,
        hidden_size=32,
        projection_dim=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=37,
        dropout=0.1,
        attention_dropout=0.1,
        initializer_range=0.02,
        scope=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.is_training = is_training
        self.hidden_size = hidden_size
        self.projection_dim = projection_dim
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.initializer_range = initializer_range
        self.scope = scope

        # in ViT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
        num_patches = (image_size // patch_size) ** 2
        self.seq_length = num_patches + 1

    def prepare_config_and_inputs(self):
        pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
        config = self.get_config()

        return config, pixel_values

    def get_config(self):
        return CLIPVisionConfig(
            image_size=self.image_size,
            patch_size=self.patch_size,
            num_channels=self.num_channels,
            hidden_size=self.hidden_size,
            projection_dim=self.projection_dim,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            dropout=self.dropout,
            attention_dropout=self.attention_dropout,
            initializer_range=self.initializer_range,
        )

    def create_and_check_model(self, config, pixel_values):
        model = CLIPVisionModel(config=config)
        model.to(torch_device)
        model.eval()
        with torch.no_grad():
            result = model(pixel_values)
        # expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
        image_size = (self.image_size, self.image_size)
        patch_size = (self.patch_size, self.patch_size)
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
        self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))

    def create_and_check_model_with_projection(self, config, pixel_values):
        model = CLIPVisionModelWithProjection(config=config)
        model.to(torch_device)
        model.eval()
        with torch.no_grad():
            result = model(pixel_values)
        # expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
        image_size = (self.image_size, self.image_size)
        patch_size = (self.patch_size, self.patch_size)
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
        self.parent.assertEqual(result.image_embeds.shape, (self.batch_size, self.projection_dim))

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, pixel_values = config_and_inputs
        inputs_dict = {"pixel_values": pixel_values}
        return config, inputs_dict

    @parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION)
    @require_torch_sdpa
    def test_eager_matches_sdpa_inference(self, *args):
        return getattr(ModelTesterMixin, self._testMethodName)(self)


class CLIPModelTesterMixin(ModelTesterMixin):
    """
    Subclass of ModelTesterMixin with methods specific to testing CLIP models.
    The SDPA equivalence test is overridden here because CLIP models may have test/vision/text+vision inputs,
    different output logits, and are not supposed to be used or tested with padding_side="left".
    """

    @require_torch_sdpa
    def test_sdpa_can_dispatch_composite_models(self):
        for model_class in self.all_model_classes:
            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
            model = model_class(config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)

                # Load the model with SDPA (it is the default, but we make it explicit for clarity)
                model_sdpa = model_class.from_pretrained(tmpdirname, attn_implementation="sdpa")
                model_sdpa = model_sdpa.eval().to(torch_device)

                # Load model with eager attention
                model_eager = model_class.from_pretrained(
                    tmpdirname,
                    attn_implementation="eager",
                )
                model_eager = model_eager.eval().to(torch_device)

            if hasattr(model_sdpa, "vision_model"):
                self.assertTrue(model_sdpa.vision_model.config._attn_implementation == "sdpa")
                self.assertTrue(model_eager.vision_model.config._attn_implementation == "eager")

            if hasattr(model_sdpa, "text_model"):
                self.assertTrue(model_sdpa.text_model.config._attn_implementation == "sdpa")
                self.assertTrue(model_eager.text_model.config._attn_implementation == "eager")

            self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
            self.assertTrue(model_eager.config._attn_implementation == "eager")


@require_torch
class CLIPVisionModelTest(CLIPModelTesterMixin, unittest.TestCase):
    """
    Here we also overwrite some of the tests of test_modeling_common.py, as CLIP does not use input_ids, inputs_embeds,
    attention_mask and seq_length.
    """

    all_model_classes = (CLIPVisionModel, CLIPVisionModelWithProjection) if is_torch_available() else ()
    fx_compatible = True
    test_pruning = False
    test_resize_embeddings = False
    test_head_masking = False

    def setUp(self):
        self.model_tester = CLIPVisionModelTester(self)
        self.config_tester = ConfigTester(self, config_class=CLIPVisionConfig, has_text_modality=False, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    @unittest.skip(reason="CLIP does not use inputs_embeds")
    def test_inputs_embeds(self):
        pass

    def test_model_get_set_embeddings(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            self.assertIsInstance(model.get_input_embeddings(), (nn.Module))
            x = model.get_output_embeddings()
            self.assertTrue(x is None or isinstance(x, nn.Linear))

    def test_forward_signature(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.forward)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            expected_arg_names = ["pixel_values"]
            self.assertListEqual(arg_names[:1], expected_arg_names)

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_model_with_projection(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model_with_projection(*config_and_inputs)

    @unittest.skip
    def test_training(self):
        pass

    @unittest.skip
    def test_training_gradient_checkpointing(self):
        pass

    @unittest.skip(
        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant(self):
        pass

    @unittest.skip(
        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass

    @slow
    def test_model_from_pretrained(self):
        model_name = "openai/clip-vit-base-patch32"
        model = CLIPVisionModel.from_pretrained(model_name)
        self.assertIsNotNone(model)

    @slow
    def test_model_with_projection_from_pretrained(self):
        model_name = "openai/clip-vit-base-patch32"
        model = CLIPVisionModelWithProjection.from_pretrained(model_name)
        self.assertIsNotNone(model)
        self.assertTrue(hasattr(model, "visual_projection"))

    @parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION)
    @require_torch_sdpa
    @is_flaky()
    def test_eager_matches_sdpa_inference(self, *args):
        # add only the flaky decorator here and call the parent test method
        return getattr(ModelTesterMixin, self._testMethodName)(self)

    @require_torch_sdpa
    def test_sdpa_can_dispatch_composite_models(self):
        super().test_sdpa_can_dispatch_composite_models()


class CLIPTextModelTester:
    def __init__(
        self,
        parent,
        batch_size=12,
        seq_length=7,
        is_training=True,
        use_input_mask=True,
        use_labels=True,
        vocab_size=99,
        hidden_size=32,
        projection_dim=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=37,
        dropout=0.1,
        attention_dropout=0.1,
        max_position_embeddings=512,
        initializer_range=0.02,
        scope=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_input_mask = use_input_mask
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.projection_dim = projection_dim
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.scope = scope

    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
            input_mask = random_attention_mask([self.batch_size, self.seq_length])

        if input_mask is not None:
            batch_size, seq_length = input_mask.shape
            rnd_start_indices = np.random.randint(1, seq_length - 1, size=(batch_size,))
            for batch_idx, start_index in enumerate(rnd_start_indices):
                input_mask[batch_idx, :start_index] = 1
                input_mask[batch_idx, start_index:] = 0

        config = self.get_config()

        return config, input_ids, input_mask

    def get_config(self):
        return CLIPTextConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            projection_dim=self.projection_dim,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            dropout=self.dropout,
            attention_dropout=self.attention_dropout,
            max_position_embeddings=self.max_position_embeddings,
            initializer_range=self.initializer_range,
        )

    def create_and_check_model(self, config, input_ids, input_mask):
        model = CLIPTextModel(config=config)
        model.to(torch_device)
        model.eval()
        with torch.no_grad():
            result = model(input_ids, attention_mask=input_mask)
            result = model(input_ids)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
        self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))

    def create_and_check_model_with_projection(self, config, input_ids, input_mask):
        model = CLIPTextModelWithProjection(config=config)
        model.to(torch_device)
        model.eval()
        with torch.no_grad():
            result = model(input_ids, attention_mask=input_mask)
            result = model(input_ids)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
        self.parent.assertEqual(result.text_embeds.shape, (self.batch_size, self.projection_dim))

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, input_ids, input_mask = config_and_inputs
        inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
        return config, inputs_dict


@require_torch
class CLIPTextModelTest(CLIPModelTesterMixin, unittest.TestCase):
    all_model_classes = (CLIPTextModel, CLIPTextModelWithProjection) if is_torch_available() else ()
    fx_compatible = True
    test_pruning = False
    test_head_masking = False
    model_split_percents = [0.5, 0.8, 0.9]

    def setUp(self):
        self.model_tester = CLIPTextModelTester(self)
        self.config_tester = ConfigTester(self, config_class=CLIPTextConfig, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_model_with_projection(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model_with_projection(*config_and_inputs)

    @unittest.skip
    def test_training(self):
        pass

    @unittest.skip
    def test_training_gradient_checkpointing(self):
        pass

    @unittest.skip(
        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant(self):
        pass

    @unittest.skip(
        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass

    @unittest.skip(reason="CLIP does not use inputs_embeds")
    def test_inputs_embeds(self):
        pass

    @slow
    def test_model_from_pretrained(self):
        model_name = "openai/clip-vit-base-patch32"
        model = CLIPTextModel.from_pretrained(model_name)
        self.assertIsNotNone(model)

    @slow
    def test_model_with_projection_from_pretrained(self):
        model_name = "openai/clip-vit-base-patch32"
        model = CLIPTextModelWithProjection.from_pretrained(model_name)
        self.assertIsNotNone(model)
        self.assertTrue(hasattr(model, "text_projection"))

    @parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION)
    @require_torch_sdpa
    @slow
    @is_flaky()
    def test_eager_matches_sdpa_inference(self, *args):
        # add only the flaky decorator here and call the parent test method
        return getattr(ModelTesterMixin, self._testMethodName)(self)

    @require_torch_sdpa
    def test_sdpa_can_dispatch_composite_models(self):
        super().test_sdpa_can_dispatch_composite_models()

    @require_torch_sdpa
    def test_sdpa_can_dispatch_on_flash(self):
        self.skipTest(reason="CLIPTextModel has two attention masks: `causal_attention_mask` and `attention_mask`")


class CLIPModelTester:
    def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=True):
        if text_kwargs is None:
            text_kwargs = {}
        if vision_kwargs is None:
            vision_kwargs = {}

        self.parent = parent
        self.text_model_tester = CLIPTextModelTester(parent, **text_kwargs)
        self.vision_model_tester = CLIPVisionModelTester(parent, **vision_kwargs)
        self.batch_size = self.text_model_tester.batch_size  # need bs for batching_equivalence test
        self.is_training = is_training

    def prepare_config_and_inputs(self):
        text_config, input_ids, attention_mask = self.text_model_tester.prepare_config_and_inputs()
        vision_config, pixel_values = self.vision_model_tester.prepare_config_and_inputs()

        config = self.get_config()

        return config, input_ids, attention_mask, pixel_values

    def get_config(self):
        return CLIPConfig.from_text_vision_configs(
            self.text_model_tester.get_config(), self.vision_model_tester.get_config(), projection_dim=64
        )

    def create_and_check_model(self, config, input_ids, attention_mask, pixel_values):
        model = CLIPModel(config).to(torch_device).eval()
        with torch.no_grad():
            result = model(input_ids, pixel_values, attention_mask)
        self.parent.assertEqual(
            result.logits_per_image.shape, (self.vision_model_tester.batch_size, self.text_model_tester.batch_size)
        )
        self.parent.assertEqual(
            result.logits_per_text.shape, (self.text_model_tester.batch_size, self.vision_model_tester.batch_size)
        )

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, input_ids, attention_mask, pixel_values = config_and_inputs
        inputs_dict = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "pixel_values": pixel_values,
            "return_loss": True,
        }
        return config, inputs_dict


@require_torch
class CLIPModelTest(CLIPModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    all_model_classes = (CLIPModel,) if is_torch_available() else ()
    pipeline_model_mapping = (
        {"feature-extraction": CLIPModel, "image-feature-extraction": CLIPVisionModel} if is_torch_available() else {}
    )
    fx_compatible = True
    test_head_masking = False
    test_pruning = False
    test_resize_embeddings = False
    test_attention_outputs = False
    _is_composite = True

    def setUp(self):
        self.model_tester = CLIPModelTester(self)
        common_properties = ["projection_dim", "logit_scale_init_value"]
        self.config_tester = ConfigTester(
            self, config_class=CLIPConfig, has_text_modality=False, common_properties=common_properties
        )

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_config(self):
        self.config_tester.run_common_tests()

    @unittest.skip(reason="Hidden_states is tested in individual model tests")
    def test_hidden_states_output(self):
        pass

    @unittest.skip(reason="Inputs_embeds is tested in individual model tests")
    def test_inputs_embeds(self):
        pass

    @unittest.skip(reason="Retain_grad is tested in individual model tests")
    def test_retain_grad_hidden_states_attentions(self):
        pass

    @unittest.skip(reason="CLIPModel does not have input/output embeddings")
    def test_model_get_set_embeddings(self):
        pass

    # override as the `logit_scale` parameter initialization is different for CLIP
    def test_initialization(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        configs_no_init = _config_zero_init(config)
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            for name, param in model.named_parameters():
                if param.requires_grad:
                    # check if `logit_scale` is initialized as per the original implementation
                    if name == "logit_scale":
                        self.assertAlmostEqual(
                            param.data.item(),
                            np.log(1 / 0.07),
                            delta=1e-3,
                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                        )
                    else:
                        self.assertIn(
                            ((param.data.mean() * 1e9).round() / 1e9).item(),
                            [0.0, 1.0],
                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                        )

    def _create_and_check_torchscript(self, config, inputs_dict):
        if not self.test_torchscript:
            self.skipTest(reason="test_torchscript is set to False")

        configs_no_init = _config_zero_init(config)  # To be sure we have no NaN
        configs_no_init.torchscript = True
        configs_no_init.return_dict = False
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            model.to(torch_device)
            model.eval()

            try:
                input_ids = inputs_dict["input_ids"]
                pixel_values = inputs_dict["pixel_values"]  # CLIP needs pixel_values
                traced_model = torch.jit.trace(model, (input_ids, pixel_values))
            except RuntimeError:
                self.fail("Couldn't trace module.")

            with tempfile.TemporaryDirectory() as tmp_dir_name:
                pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")

                try:
                    torch.jit.save(traced_model, pt_file_name)
                except Exception:
                    self.fail("Couldn't save module.")

                try:
                    loaded_model = torch.jit.load(pt_file_name)
                except Exception:
                    self.fail("Couldn't load module.")

            model.to(torch_device)
            model.eval()

            loaded_model.to(torch_device)
            loaded_model.eval()

            model_state_dict = model.state_dict()
            loaded_model_state_dict = loaded_model.state_dict()

            non_persistent_buffers = {}
            for key in loaded_model_state_dict.keys():
                if key not in model_state_dict.keys():
                    non_persistent_buffers[key] = loaded_model_state_dict[key]

            loaded_model_state_dict = {
                key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers
            }

            self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))

            model_buffers = list(model.buffers())
            for non_persistent_buffer in non_persistent_buffers.values():
                found_buffer = False
                for i, model_buffer in enumerate(model_buffers):
                    if torch.equal(non_persistent_buffer, model_buffer):
                        found_buffer = True
                        break

                self.assertTrue(found_buffer)
                model_buffers.pop(i)

            models_equal = True
            for layer_name, p1 in model_state_dict.items():
                p2 = loaded_model_state_dict[layer_name]
                if p1.data.ne(p2.data).sum() > 0:
                    models_equal = False

            self.assertTrue(models_equal)

    def test_load_vision_text_config(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        # Save CLIPConfig and check if we can load CLIPVisionConfig from it
        with tempfile.TemporaryDirectory() as tmp_dir_name:
            config.save_pretrained(tmp_dir_name)
            vision_config = CLIPVisionConfig.from_pretrained(tmp_dir_name)
            self.assertDictEqual(config.vision_config.to_dict(), vision_config.to_dict())

        # Save CLIPConfig and check if we can load CLIPTextConfig from it
        with tempfile.TemporaryDirectory() as tmp_dir_name:
            config.save_pretrained(tmp_dir_name)
            text_config = CLIPTextConfig.from_pretrained(tmp_dir_name)
            self.assertDictEqual(config.text_config.to_dict(), text_config.to_dict())

    @slow
    def test_model_from_pretrained(self):
        model_name = "openai/clip-vit-base-patch32"
        model = CLIPModel.from_pretrained(model_name)
        self.assertIsNotNone(model)

    @parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION)
    @require_torch_sdpa
    @slow
    @is_flaky()
    def test_eager_matches_sdpa_inference(self, *args):
        # add only the flaky decorator here and call the parent test method
        return getattr(ModelTesterMixin, self._testMethodName)(self)

    @require_torch_sdpa
    def test_sdpa_can_dispatch_composite_models(self):
        super().test_sdpa_can_dispatch_composite_models()

    @require_torch_sdpa
    def test_sdpa_can_dispatch_on_flash(self):
        self.skipTest(reason="CLIP text tower has two attention masks: `causal_attention_mask` and `attention_mask`")

    @require_torch_sdpa
    def test_sdpa_can_compile_dynamic(self):
        self.skipTest(reason="CLIP model can't be compiled dynamic, error in `clip_loss`")

    @require_flash_attn
    @require_torch_gpu
    @mark.flash_attn_test
    @slow
    def test_flash_attn_2_inference_equivalence(self):
        for model_class in self.all_model_classes:
            if not model_class._supports_flash_attn_2:
                self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")

            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
            model = model_class(config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                model_fa = model_class.from_pretrained(
                    tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
                )
                model_fa.to(torch_device)

                model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
                model.to(torch_device)

                dummy_pixel_values = inputs_dict["pixel_values"].to(torch.bfloat16)
                dummy_input_ids = inputs_dict["input_ids"]

                outputs = model(pixel_values=dummy_pixel_values, input_ids=dummy_input_ids, output_hidden_states=True)
                outputs_fa = model_fa(
                    pixel_values=dummy_pixel_values, input_ids=dummy_input_ids, output_hidden_states=True
                )

                self.assertTrue(
                    torch.allclose(outputs.logits_per_image, outputs_fa.logits_per_image, atol=4e-2, rtol=4e-2),
                    f"Image logits max diff: {torch.max(torch.abs(outputs.logits_per_image - outputs_fa.logits_per_image))}",
                )
                self.assertTrue(
                    torch.allclose(outputs.logits_per_text, outputs_fa.logits_per_text, atol=4e-2, rtol=4e-2),
                    f"Text logits max diff: {torch.max(torch.abs(outputs.logits_per_text - outputs_fa.logits_per_text))}",
                )

    @require_flash_attn
    @require_torch_gpu
    @mark.flash_attn_test
    def test_flash_attn_2_inference_equivalence_right_padding(self):
        for model_class in self.all_model_classes:
            if not model_class._supports_flash_attn_2:
                self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")

            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
            model = model_class(config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                model_fa = model_class.from_pretrained(
                    tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
                )
                model_fa.to(torch_device)

                model = model_class.from_pretrained(
                    tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="eager"
                )
                model.to(torch_device)

                dummy_pixel_values = inputs_dict["pixel_values"].to(torch.bfloat16)
                dummy_input_ids = inputs_dict["input_ids"]
                dummy_pixel_mask = inputs_dict["attention_mask"]

                # right padding
                dummy_pixel_mask[:] = 1
                dummy_pixel_mask[:, -1:] = 0

                outputs = model(pixel_values=dummy_pixel_values, input_ids=dummy_input_ids, output_hidden_states=True)
                outputs_fa = model_fa(
                    pixel_values=dummy_pixel_values, input_ids=dummy_input_ids, output_hidden_states=True
                )

                logits_per_image_eager = outputs.logits_per_image[:, :-1]
                logits_per_text_eager = outputs.logits_per_text[:, :-1]

                logits_per_image_sdpa = outputs_fa.logits_per_image[:, :-1]
                logits_per_text_sdpa = outputs_fa.logits_per_text[:, :-1]

                self.assertTrue(
                    torch.allclose(logits_per_image_eager, logits_per_image_sdpa, atol=4e-2, rtol=4e-2),
                    f"Image logits max diff: {torch.max(torch.abs(logits_per_image_eager - logits_per_image_sdpa))}",
                )
                self.assertTrue(
                    torch.allclose(logits_per_text_eager, logits_per_text_sdpa, atol=4e-2, rtol=4e-2),
                    f"Text logits max diff: {torch.max(torch.abs(logits_per_text_eager - logits_per_text_sdpa))}",
                )


class CLIPForImageClassificationModelTester(CLIPModelTester):
    def __init__(self, parent):
        super().__init__(parent)
        self.batch_size = self.vision_model_tester.batch_size
        self.num_hidden_layers = self.vision_model_tester.num_hidden_layers
        self.hidden_size = self.vision_model_tester.hidden_size
        self.seq_length = self.vision_model_tester.seq_length

    def prepare_config_and_inputs(self):
        _, pixel_values = self.vision_model_tester.prepare_config_and_inputs()
        config = self.get_config()

        return config, pixel_values

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, pixel_values = config_and_inputs
        inputs_dict = {"pixel_values": pixel_values}
        return config, inputs_dict


@require_torch
class CLIPForImageClassificationModelTest(CLIPModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    all_model_classes = (CLIPForImageClassification,) if is_torch_available() else ()
    pipeline_model_mapping = {"image-classification": CLIPForImageClassification} if is_torch_available() else {}
    fx_compatible = False
    test_head_masking = False
    test_pruning = False
    test_resize_embeddings = False
    test_attention_outputs = False
    _is_composite = True

    def setUp(self):
        self.model_tester = CLIPForImageClassificationModelTester(self)

    @unittest.skip(reason="CLIPForImageClassification does not support inputs_embeds")
    def test_inputs_embeds(self):
        pass

    @unittest.skip(reason="CLIPForImageClassification does not support inputs_embeds")
    def test_model_get_set_embeddings(self):
        pass

    @unittest.skip(reason="CLIPForImageClassification does not support gradient checkpointing yet")
    def test_training_gradient_checkpointing(self):
        pass

    @unittest.skip(reason="CLIPForImageClassification does not support gradient checkpointing yet")
    def test_training_gradient_checkpointing_use_reentrant(self):
        pass

    @unittest.skip(reason="CLIPForImageClassification does not support gradient checkpointing yet")
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass

    @unittest.skip(reason="CLIP uses the same initialization scheme as the Flax original implementation")
    def test_initialization(self):
        pass

    @parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION)
    @require_torch_sdpa
    @slow
    @is_flaky()
    def test_eager_matches_sdpa_inference(self, *args):
        # add only the flaky decorator here and call the parent test method
        return getattr(ModelTesterMixin, self._testMethodName)(self)

    @require_torch_sdpa
    def test_sdpa_can_dispatch_composite_models(self):
        super().test_sdpa_can_dispatch_composite_models()


# We will verify our results on an image of cute cats
def prepare_img():
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    im = Image.open(requests.get(url, stream=True).raw)
    return im


@require_vision
@require_torch
class CLIPModelIntegrationTest(unittest.TestCase):
    @slow
    def test_inference(self):
        model_name = "openai/clip-vit-base-patch32"
        model = CLIPModel.from_pretrained(model_name, attn_implementation="sdpa").to(torch_device)
        processor = CLIPProcessor.from_pretrained(model_name)

        image = prepare_img()
        inputs = processor(
            text=["a photo of a cat", "a photo of a dog"], images=image, padding=True, return_tensors="pt"
        ).to(torch_device)

        # forward pass
        with torch.no_grad():
            outputs = model(**inputs)

        # verify the logits
        self.assertEqual(
            outputs.logits_per_image.shape,
            torch.Size((inputs.pixel_values.shape[0], inputs.input_ids.shape[0])),
        )
        self.assertEqual(
            outputs.logits_per_text.shape,
            torch.Size((inputs.input_ids.shape[0], inputs.pixel_values.shape[0])),
        )

        expected_logits = torch.tensor([[24.5701, 19.3049]], device=torch_device)

        torch.testing.assert_close(outputs.logits_per_image, expected_logits, rtol=1e-3, atol=1e-3)

    @slow
    def test_inference_interpolate_pos_encoding(self):
        # CLIP models have an `interpolate_pos_encoding` argument in their forward method,
        # allowing to interpolate the pre-trained position embeddings in order to use
        # the model on higher resolutions. The DINO model by Facebook AI leverages this
        # to visualize self-attention on higher resolution images.
        model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(torch_device)

        processor = CLIPProcessor.from_pretrained(
            "openai/clip-vit-base-patch32", size={"height": 180, "width": 180}, crop_size={"height": 180, "width": 180}
        )

        image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
        inputs = processor(text="what's in the image", images=image, return_tensors="pt").to(torch_device)

        # interpolate_pos_encoding=False should raise a ValueError
        with self.assertRaises(ValueError, msg="doesn't match model"):
            with torch.no_grad():
                model(**inputs, interpolate_pos_encoding=False)

        # forward pass
        with torch.no_grad():
            outputs = model(**inputs, interpolate_pos_encoding=True)

        # verify the logits
        expected_shape = torch.Size((1, 26, 768))

        self.assertEqual(outputs.vision_model_output.last_hidden_state.shape, expected_shape)

        expected_slice = torch.tensor(
            [[-0.1538, 0.0322, -0.3235], [0.2893, 0.1135, -0.5708], [0.0461, 0.1540, -0.6018]]
        ).to(torch_device)

        torch.testing.assert_close(
            outputs.vision_model_output.last_hidden_state[0, :3, :3], expected_slice, rtol=6e-3, atol=4e-4
        )
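As a usage reference, the integration test above boils down to the standard zero-shot image/text matching flow. A minimal sketch of that same flow, assuming network access to the Hub and to the COCO image URL used by prepare_img; the prompt texts are just examples:

import requests
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(
    text=["a photo of a cat", "a photo of a dog"], images=image, padding=True, return_tensors="pt"
)

with torch.no_grad():
    logits = model(**inputs).logits_per_image  # shape (num_images, num_texts) = (1, 2)

# Softmax over the text axis turns the similarity logits into per-image probabilities;
# for this image the "cat" prompt should dominate.
print(logits.softmax(dim=-1))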
docs/transformers/tests/models/clip/test_modeling_flax_clip.py
ADDED
@@ -0,0 +1,468 @@
import inspect
import tempfile
import unittest

import numpy as np

from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig, is_flax_available
from transformers.testing_utils import require_flax, slow

from ...test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask


if is_flax_available():
    import jax

    from transformers.models.clip.modeling_flax_clip import (
        FlaxCLIPModel,
        FlaxCLIPTextModel,
        FlaxCLIPTextModelWithProjection,
        FlaxCLIPVisionModel,
    )


class FlaxCLIPVisionModelTester:
    def __init__(
        self,
        parent,
        batch_size=12,
        image_size=30,
        patch_size=2,
        num_channels=3,
        is_training=True,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=37,
        dropout=0.1,
        attention_dropout=0.1,
        initializer_range=0.02,
        scope=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.is_training = is_training
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.initializer_range = initializer_range
        self.scope = scope

    def prepare_config_and_inputs(self):
        pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
        config = CLIPVisionConfig(
            image_size=self.image_size,
            patch_size=self.patch_size,
            num_channels=self.num_channels,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            dropout=self.dropout,
            attention_dropout=self.attention_dropout,
            initializer_range=self.initializer_range,
        )

        return config, pixel_values

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, pixel_values = config_and_inputs
        inputs_dict = {"pixel_values": pixel_values}
        return config, inputs_dict


@require_flax
class FlaxCLIPVisionModelTest(FlaxModelTesterMixin, unittest.TestCase):
    """
    Here we also overwrite some of the tests of test_modeling_common.py, as CLIP does not use input_ids, inputs_embeds,
    attention_mask and seq_length.
    """

    all_model_classes = (FlaxCLIPVisionModel,) if is_flax_available() else ()

    def setUp(self):
        self.model_tester = FlaxCLIPVisionModelTester(self)

    def test_forward_signature(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.__call__)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            expected_arg_names = ["pixel_values"]
            self.assertListEqual(arg_names[:1], expected_arg_names)

    def test_jit_compilation(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):
                prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
                model = model_class(config)

                @jax.jit
                def model_jitted(pixel_values, **kwargs):
                    return model(pixel_values=pixel_values, **kwargs).to_tuple()

                with self.subTest("JIT Enabled"):
                    jitted_outputs = model_jitted(**prepared_inputs_dict)

                with self.subTest("JIT Disabled"):
                    with jax.disable_jit():
                        outputs = model_jitted(**prepared_inputs_dict)

                self.assertEqual(len(outputs), len(jitted_outputs))
                for jitted_output, output in zip(jitted_outputs, outputs):
                    self.assertEqual(jitted_output.shape, output.shape)

    def test_hidden_states_output(self):
        def check_hidden_states_output(inputs_dict, config, model_class):
            model = model_class(config)

            outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            hidden_states = outputs.hidden_states

            self.assertEqual(len(hidden_states), self.model_tester.num_hidden_layers + 1)

            # CLIP has a different seq_length
            image_size = (self.model_tester.image_size, self.model_tester.image_size)
            patch_size = (self.model_tester.patch_size, self.model_tester.patch_size)
            num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
            seq_length = num_patches + 1

            self.assertListEqual(
                list(hidden_states[0].shape[-2:]),
                [seq_length, self.model_tester.hidden_size],
            )

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            inputs_dict["output_hidden_states"] = True
            check_hidden_states_output(inputs_dict, config, model_class)

            # check that output_hidden_states also work using config
            del inputs_dict["output_hidden_states"]
            config.output_hidden_states = True

            check_hidden_states_output(inputs_dict, config, model_class)

    def test_attention_outputs(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.return_dict = True

        # in CLIP, the seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
        image_size = (self.model_tester.image_size, self.model_tester.image_size)
        patch_size = (self.model_tester.patch_size, self.model_tester.patch_size)
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        seq_length = num_patches + 1

        for model_class in self.all_model_classes:
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = False
            model = model_class(config)
            outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = outputs.attentions
            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)

            # check that output_attentions also work using config
            del inputs_dict["output_attentions"]
            config.output_attentions = True
            model = model_class(config)
            outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = outputs.attentions
            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)

            self.assertListEqual(
                list(attentions[0].shape[-3:]),
                [self.model_tester.num_attention_heads, seq_length, seq_length],
            )
            out_len = len(outputs)

            # Check attention is always last and order is fine
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = True
            model = model_class(config)
            outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            added_hidden_states = 1
            self.assertEqual(out_len + added_hidden_states, len(outputs))

            self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
            self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)

            self.assertListEqual(
                list(self_attentions[0].shape[-3:]),
                [self.model_tester.num_attention_heads, seq_length, seq_length],
            )

    # FlaxCLIPVisionModel does not have any base model
    def test_save_load_from_base(self):
        pass

    # FlaxCLIPVisionModel does not have any base model
    def test_save_load_to_base(self):
        pass

    @slow
    def test_model_from_pretrained(self):
        for model_class_name in self.all_model_classes:
            model = model_class_name.from_pretrained("openai/clip-vit-base-patch32", from_pt=True)
            outputs = model(np.ones((1, 3, 224, 224)))
            self.assertIsNotNone(outputs)


class FlaxCLIPTextModelTester:
    def __init__(
        self,
        parent,
        batch_size=12,
        seq_length=7,
        is_training=True,
        use_input_mask=True,
        use_labels=True,
        vocab_size=99,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=37,
        dropout=0.1,
        attention_dropout=0.1,
        max_position_embeddings=512,
        initializer_range=0.02,
        scope=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_input_mask = use_input_mask
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.scope = scope

    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
            input_mask = random_attention_mask([self.batch_size, self.seq_length])

        if input_mask is not None:
            batch_size, seq_length = input_mask.shape
            rnd_start_indices = np.random.randint(1, seq_length - 1, size=(batch_size,))
            for batch_idx, start_index in enumerate(rnd_start_indices):
                input_mask[batch_idx, :start_index] = 1
                input_mask[batch_idx, start_index:] = 0

        config = CLIPTextConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            dropout=self.dropout,
            attention_dropout=self.attention_dropout,
            max_position_embeddings=self.max_position_embeddings,
            initializer_range=self.initializer_range,
        )

        return config, input_ids, input_mask

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, input_ids, input_mask = config_and_inputs
        inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
        return config, inputs_dict


@require_flax
class FlaxCLIPTextModelTest(FlaxModelTesterMixin, unittest.TestCase):
    all_model_classes = (FlaxCLIPTextModel, FlaxCLIPTextModelWithProjection) if is_flax_available() else ()

    def setUp(self):
        self.model_tester = FlaxCLIPTextModelTester(self)

    # FlaxCLIPTextModel does not have any base model
    def test_save_load_from_base(self):
        pass

    # FlaxCLIPTextModel does not have any base model
    def test_save_load_to_base(self):
        pass

    @slow
    def test_model_from_pretrained(self):
        for model_class_name in self.all_model_classes:
            model = model_class_name.from_pretrained("openai/clip-vit-base-patch32", from_pt=True)
            outputs = model(np.ones((1, 1)))
            self.assertIsNotNone(outputs)


class FlaxCLIPModelTester:
    def __init__(self, parent, is_training=True):
        self.parent = parent
        self.text_model_tester = FlaxCLIPTextModelTester(parent)
        self.vision_model_tester = FlaxCLIPVisionModelTester(parent)
        self.is_training = is_training

    def prepare_config_and_inputs(self):
        text_config, input_ids, attention_mask = self.text_model_tester.prepare_config_and_inputs()
        vision_config, pixel_values = self.vision_model_tester.prepare_config_and_inputs()

        config = CLIPConfig.from_text_vision_configs(text_config, vision_config, projection_dim=64)

        return config, input_ids, attention_mask, pixel_values

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, input_ids, attention_mask, pixel_values = config_and_inputs
        inputs_dict = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "pixel_values": pixel_values,
        }
        return config, inputs_dict


@require_flax
class FlaxCLIPModelTest(FlaxModelTesterMixin, unittest.TestCase):
    all_model_classes = (FlaxCLIPModel,) if is_flax_available() else ()
    test_attention_outputs = False

    def setUp(self):
        self.model_tester = FlaxCLIPModelTester(self)

    # hidden_states are tested in individual model tests
    def test_hidden_states_output(self):
        pass

    def test_jit_compilation(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):
                prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
                model = model_class(config)

                @jax.jit
                def model_jitted(input_ids, pixel_values, **kwargs):
                    return model(input_ids=input_ids, pixel_values=pixel_values, **kwargs).to_tuple()

                with self.subTest("JIT Enabled"):
                    jitted_outputs = model_jitted(**prepared_inputs_dict)

                with self.subTest("JIT Disabled"):
                    with jax.disable_jit():
                        outputs = model_jitted(**prepared_inputs_dict)

                self.assertEqual(len(outputs), len(jitted_outputs))
                for jitted_output, output in zip(jitted_outputs[:4], outputs[:4]):
                    self.assertEqual(jitted_output.shape, output.shape)

    def test_forward_signature(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.__call__)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            expected_arg_names = ["input_ids", "pixel_values", "attention_mask", "position_ids"]
            self.assertListEqual(arg_names[:4], expected_arg_names)

    def test_get_image_features(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        model = FlaxCLIPModel(config)

        @jax.jit
        def model_jitted(pixel_values):
            return model.get_image_features(pixel_values=pixel_values)

        with self.subTest("JIT Enabled"):
            jitted_output = model_jitted(inputs_dict["pixel_values"])

        with self.subTest("JIT Disabled"):
            with jax.disable_jit():
                output = model_jitted(inputs_dict["pixel_values"])

        self.assertEqual(jitted_output.shape, output.shape)
        self.assertTrue(np.allclose(jitted_output, output, atol=1e-3))

    def test_get_text_features(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        model = FlaxCLIPModel(config)

        @jax.jit
        def model_jitted(input_ids, attention_mask, **kwargs):
            return model.get_text_features(input_ids=input_ids, attention_mask=attention_mask)

        with self.subTest("JIT Enabled"):
            jitted_output = model_jitted(**inputs_dict)

        with self.subTest("JIT Disabled"):
            with jax.disable_jit():
                output = model_jitted(**inputs_dict)

        self.assertEqual(jitted_output.shape, output.shape)
        self.assertTrue(np.allclose(jitted_output, output, atol=1e-3))

    @slow
    def test_model_from_pretrained(self):
        for model_class_name in self.all_model_classes:
            model = model_class_name.from_pretrained("openai/clip-vit-base-patch32", from_pt=True)
            outputs = model(input_ids=np.ones((1, 1)), pixel_values=np.ones((1, 3, 224, 224)))
            self.assertIsNotNone(outputs)

    # overwrite from common since FlaxCLIPModel returns nested output
    # which is not supported in the common test
    def test_from_pretrained_save_pretrained(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            if model_class.__name__ != "FlaxBertModel":
                continue

            with self.subTest(model_class.__name__):
                model = model_class(config)

                prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
                outputs = model(**prepared_inputs_dict).to_tuple()

                # verify that normal save_pretrained works as expected
                with tempfile.TemporaryDirectory() as tmpdirname:
                    model.save_pretrained(tmpdirname)
                    model_loaded = model_class.from_pretrained(tmpdirname)

                outputs_loaded = model_loaded(**prepared_inputs_dict).to_tuple()[:4]
                for output_loaded, output in zip(outputs_loaded, outputs):
                    self.assert_almost_equals(output_loaded, output, 1e-3)

                # verify that save_pretrained for distributed training
                # with `params=params` works as expected
                with tempfile.TemporaryDirectory() as tmpdirname:
                    model.save_pretrained(tmpdirname, params=model.params)
                    model_loaded = model_class.from_pretrained(tmpdirname)

                outputs_loaded = model_loaded(**prepared_inputs_dict).to_tuple()[:4]
                for output_loaded, output in zip(outputs_loaded, outputs):
                    self.assert_almost_equals(output_loaded, output, 1e-3)
|
docs/transformers/tests/models/clip/test_modeling_tf_clip.py
ADDED
|
@@ -0,0 +1,662 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the TensorFlow CLIP model."""

from __future__ import annotations

import inspect
import os
import tempfile
import unittest
from importlib import import_module

import requests

from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
from transformers.testing_utils import require_tf, require_vision, slow
from transformers.utils import is_tf_available, is_vision_available

from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin


if is_tf_available():
    import tensorflow as tf

    from transformers import TFCLIPModel, TFCLIPTextModel, TFCLIPVisionModel, TFSharedEmbeddings
    from transformers.modeling_tf_utils import keras


if is_vision_available():
    from PIL import Image

    from transformers import CLIPProcessor


class TFCLIPVisionModelTester:
    def __init__(
        self,
        parent,
        batch_size=12,
        image_size=30,
        patch_size=2,
        num_channels=3,
        is_training=True,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=37,
        dropout=0.1,
        attention_dropout=0.1,
        initializer_range=0.02,
        scope=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.is_training = is_training
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.initializer_range = initializer_range
        self.scope = scope

    def prepare_config_and_inputs(self):
        pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
        config = self.get_config()

        return config, pixel_values

    def get_config(self):
        return CLIPVisionConfig(
            image_size=self.image_size,
            patch_size=self.patch_size,
            num_channels=self.num_channels,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            dropout=self.dropout,
            attention_dropout=self.attention_dropout,
            initializer_range=self.initializer_range,
        )

    def create_and_check_model(self, config, pixel_values):
        model = TFCLIPVisionModel(config=config)
        result = model(pixel_values, training=False)
        # expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
        image_size = (self.image_size, self.image_size)
        patch_size = (self.patch_size, self.patch_size)
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
        self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, pixel_values = config_and_inputs
        inputs_dict = {"pixel_values": pixel_values}
        return config, inputs_dict


@require_tf
class TFCLIPVisionModelTest(TFModelTesterMixin, unittest.TestCase):
    """
    Here we also overwrite some of the tests of test_modeling_common.py, as CLIP does not use input_ids, inputs_embeds,
    attention_mask and seq_length.
    """

    all_model_classes = (TFCLIPVisionModel,) if is_tf_available() else ()

    test_pruning = False
    test_resize_embeddings = False
    test_head_masking = False
    test_onnx = False

    def setUp(self):
        self.model_tester = TFCLIPVisionModelTester(self)
        self.config_tester = ConfigTester(self, config_class=CLIPVisionConfig, has_text_modality=False, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_inputs_embeds(self):
        # CLIP does not use inputs_embeds
        pass

    def test_graph_mode_with_inputs_embeds(self):
        # CLIP does not use inputs_embeds
        pass

    def test_model_common_attributes(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            self.assertIsInstance(model.get_input_embeddings(), (keras.layers.Layer))
            x = model.get_output_embeddings()
            self.assertTrue(x is None or isinstance(x, keras.layers.Layer))

    def test_forward_signature(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.call)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            expected_arg_names = ["pixel_values"]
            self.assertListEqual(arg_names[:1], expected_arg_names)

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_attention_outputs(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.return_dict = True

        # in CLIP, the seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
        image_size = (self.model_tester.image_size, self.model_tester.image_size)
        patch_size = (self.model_tester.patch_size, self.model_tester.patch_size)
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        seq_len = num_patches + 1

        for model_class in self.all_model_classes:
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = False
            config.return_dict = True
            model = model_class(config)
            outputs = model(**self._prepare_for_class(inputs_dict, model_class), training=False)
            attentions = outputs.attentions
            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)

            # check that output_attentions also work using config
            del inputs_dict["output_attentions"]
            config.output_attentions = True
            model = model_class(config)
            outputs = model(**self._prepare_for_class(inputs_dict, model_class), training=False)
            attentions = outputs.attentions
            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)

            out_len = len(outputs)

            # Check attention is always last and order is fine
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = True
            model = model_class(config)
            outputs = model(**self._prepare_for_class(inputs_dict, model_class), training=False)

            added_hidden_states = 1
            self.assertEqual(out_len + added_hidden_states, len(outputs))

            self_attentions = outputs.attentions

            self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)

            self.assertListEqual(
                list(self_attentions[0].shape[-3:]),
                [self.model_tester.num_attention_heads, seq_len, seq_len],
            )

    def test_hidden_states_output(self):
        def check_hidden_states_output(inputs_dict, config, model_class):
            model = model_class(config)

            outputs = model(**self._prepare_for_class(inputs_dict, model_class), training=False)

            hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states

            expected_num_layers = getattr(
                self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
            )
            self.assertEqual(len(hidden_states), expected_num_layers)

            # CLIP has a different seq_length
            image_size = (self.model_tester.image_size, self.model_tester.image_size)
            patch_size = (self.model_tester.patch_size, self.model_tester.patch_size)
            num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
            seq_length = num_patches + 1

            self.assertListEqual(
                list(hidden_states[0].shape[-2:]),
                [seq_length, self.model_tester.hidden_size],
            )

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            inputs_dict["output_hidden_states"] = True
            check_hidden_states_output(inputs_dict, config, model_class)

            # check that output_hidden_states also work using config
            del inputs_dict["output_hidden_states"]
            config.output_hidden_states = True

            check_hidden_states_output(inputs_dict, config, model_class)

    @slow
    def test_model_from_pretrained(self):
        model_name = "openai/clip-vit-base-patch32"
        model = TFCLIPVisionModel.from_pretrained(model_name)
        self.assertIsNotNone(model)

    @slow
    def test_saved_model_creation_extended(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.output_hidden_states = True
        config.output_attentions = True

        if hasattr(config, "use_cache"):
            config.use_cache = True

        # in CLIP, the seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
        image_size = (self.model_tester.image_size, self.model_tester.image_size)
        patch_size = (self.model_tester.patch_size, self.model_tester.patch_size)
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        seq_len = num_patches + 1

        for model_class in self.all_model_classes:
            class_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
            model = model_class(config)
            num_out = len(model(class_inputs_dict))

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname, saved_model=True)
                saved_model_dir = os.path.join(tmpdirname, "saved_model", "1")
                model = keras.models.load_model(saved_model_dir)
                outputs = model(class_inputs_dict)
                output_hidden_states = outputs["hidden_states"]
                output_attentions = outputs["attentions"]

                # Check num outputs
                self.assertEqual(len(outputs), num_out)

                # Check num layers
                expected_num_layers = getattr(
                    self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
                )

                self.assertEqual(len(output_hidden_states), expected_num_layers)
                self.assertEqual(len(output_attentions), self.model_tester.num_hidden_layers)

                # Check attention outputs
                image_size = (self.model_tester.image_size, self.model_tester.image_size)
                patch_size = (self.model_tester.patch_size, self.model_tester.patch_size)
                num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
                seq_len = num_patches + 1

                self.assertListEqual(
                    list(output_attentions[0].shape[-3:]),
                    [self.model_tester.num_attention_heads, seq_len, seq_len],
                )

                # Check hidden states
                self.assertListEqual(
                    list(output_hidden_states[0].shape[-2:]),
                    [seq_len, self.model_tester.hidden_size],
                )


class TFCLIPTextModelTester:
    def __init__(
        self,
        parent,
        batch_size=12,
        seq_length=7,
        is_training=True,
        use_input_mask=True,
        use_labels=True,
        vocab_size=99,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=37,
        dropout=0.1,
        attention_dropout=0.1,
        max_position_embeddings=512,
        initializer_range=0.02,
        scope=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_input_mask = use_input_mask
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.scope = scope

    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
            input_mask = random_attention_mask([self.batch_size, self.seq_length])
            # make sure the first token has attention mask `1` to ensure that, after combining the causal mask, there
            # is still at least one token being attended to for each batch.
            # TODO: Change `random_attention_mask` in PT/TF/Flax common test file, after a discussion with the team.
            input_mask = tf.concat(
                [tf.ones_like(input_mask[:, :1], dtype=input_mask.dtype), input_mask[:, 1:]], axis=-1
            )

        config = self.get_config()

        return config, input_ids, input_mask

    def get_config(self):
        return CLIPTextConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            dropout=self.dropout,
            attention_dropout=self.attention_dropout,
            max_position_embeddings=self.max_position_embeddings,
            initializer_range=self.initializer_range,
        )

    def create_and_check_model(self, config, input_ids, input_mask):
        model = TFCLIPTextModel(config=config)
        result = model(input_ids, attention_mask=input_mask, training=False)
        result = model(input_ids, training=False)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
        self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, input_ids, input_mask = config_and_inputs
        inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
        return config, inputs_dict


@require_tf
class TFCLIPTextModelTest(TFModelTesterMixin, unittest.TestCase):
    all_model_classes = (TFCLIPTextModel,) if is_tf_available() else ()
    test_pruning = False
    test_head_masking = False
    test_onnx = False

    def setUp(self):
        self.model_tester = TFCLIPTextModelTester(self)
        self.config_tester = ConfigTester(self, config_class=CLIPTextConfig, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_inputs_embeds(self):
        # CLIP does not use inputs_embeds
        pass

    @slow
    def test_model_from_pretrained(self):
        model_name = "openai/clip-vit-base-patch32"
        model = TFCLIPTextModel.from_pretrained(model_name)
        self.assertIsNotNone(model)

    @slow
    def test_saved_model_creation_extended(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.output_hidden_states = True
        config.output_attentions = True

        if hasattr(config, "use_cache"):
            config.use_cache = True

        for model_class in self.all_model_classes:
            class_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
            model = model_class(config)
            num_out = len(model(class_inputs_dict))

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname, saved_model=True)
                saved_model_dir = os.path.join(tmpdirname, "saved_model", "1")
                model = keras.models.load_model(saved_model_dir)
                outputs = model(class_inputs_dict)
                output_hidden_states = outputs["hidden_states"]
                output_attentions = outputs["attentions"]

                # Check number of outputs
                self.assertEqual(len(outputs), num_out)

                # Check number of layers
                expected_num_layers = getattr(
                    self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
                )

                # Check hidden states
                self.assertEqual(len(output_hidden_states), expected_num_layers)
                self.assertListEqual(
                    list(output_hidden_states[0].shape[-2:]),
                    [self.model_tester.seq_length, self.model_tester.hidden_size],
                )

                # Check attention outputs
                self.assertEqual(len(output_attentions), self.model_tester.num_hidden_layers)

                seq_length = self.model_tester.seq_length
                key_length = getattr(self.model_tester, "key_length", seq_length)

                self.assertListEqual(
                    list(output_attentions[0].shape[-3:]),
                    [self.model_tester.num_attention_heads, seq_length, key_length],
                )


class TFCLIPModelTester:
    def __init__(self, parent, is_training=True):
        self.parent = parent
        self.text_model_tester = TFCLIPTextModelTester(parent)
        self.vision_model_tester = TFCLIPVisionModelTester(parent)
        self.is_training = is_training

    def prepare_config_and_inputs(self):
        text_config, input_ids, attention_mask = self.text_model_tester.prepare_config_and_inputs()
        vision_config, pixel_values = self.vision_model_tester.prepare_config_and_inputs()

        config = self.get_config()

        return config, input_ids, attention_mask, pixel_values

    def get_config(self):
        return CLIPConfig.from_text_vision_configs(
            self.text_model_tester.get_config(), self.vision_model_tester.get_config(), projection_dim=64
        )

    def create_and_check_model(self, config, input_ids, attention_mask, pixel_values):
        model = TFCLIPModel(config)
        result = model(input_ids, pixel_values, attention_mask, training=False)
        self.parent.assertEqual(
            result.logits_per_image.shape, (self.vision_model_tester.batch_size, self.text_model_tester.batch_size)
        )
        self.parent.assertEqual(
            result.logits_per_text.shape, (self.text_model_tester.batch_size, self.vision_model_tester.batch_size)
        )

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, input_ids, attention_mask, pixel_values = config_and_inputs
        inputs_dict = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "pixel_values": pixel_values,
            "return_loss": True,
        }
        return config, inputs_dict


@require_tf
class TFCLIPModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    all_model_classes = (TFCLIPModel,) if is_tf_available() else ()
    pipeline_model_mapping = {"feature-extraction": TFCLIPModel} if is_tf_available() else {}
    test_head_masking = False
    test_pruning = False
    test_resize_embeddings = False
    test_attention_outputs = False
    test_onnx = False

    def setUp(self):
        self.model_tester = TFCLIPModelTester(self)

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    # hidden_states are tested in individual model tests
    def test_hidden_states_output(self):
        pass

    # input_embeds are tested in individual model tests
    def test_inputs_embeds(self):
        pass

    # CLIPModel does not have input/output embeddings
    def test_model_common_attributes(self):
        pass

    # overwrite from common since `TFCLIPModelTester` sets `return_loss` to `True`, which makes the preparation of
    # `symbolic_inputs` fail.
    def test_keras_save_load(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        # remove `return_loss` to make code work
        if self.__class__.__name__ == "TFCLIPModelTest":
            inputs_dict.pop("return_loss", None)

        tf_main_layer_classes = {
            module_member
            for model_class in self.all_model_classes
            for module in (import_module(model_class.__module__),)
            for module_member_name in dir(module)
            if module_member_name.endswith("MainLayer")
            # This condition is required, since `modeling_tf_clip.py` has 3 classes whose names end with `MainLayer`.
            and module_member_name[: -len("MainLayer")] == model_class.__name__[: -len("Model")]
            for module_member in (getattr(module, module_member_name),)
            if isinstance(module_member, type)
            and keras.layers.Layer in module_member.__bases__
            and getattr(module_member, "_keras_serializable", False)
        }
        for main_layer_class in tf_main_layer_classes:
            # T5MainLayer needs an embed_tokens parameter when called without the inputs_embeds parameter
            if "T5" in main_layer_class.__name__:
                # Take the same values as in TFT5ModelTester for this shared layer
                shared = TFSharedEmbeddings(99, 32, name="shared")
                config.use_cache = inputs_dict.pop("use_cache", None)
                main_layer = main_layer_class(config, embed_tokens=shared)
            else:
                main_layer = main_layer_class(config)

            symbolic_inputs = {
                name: keras.Input(tensor.shape[1:], dtype=tensor.dtype) for name, tensor in inputs_dict.items()
            }

            model = keras.Model(symbolic_inputs, outputs=main_layer(symbolic_inputs))
            outputs = model(inputs_dict)

            with tempfile.TemporaryDirectory() as tmpdirname:
                filepath = os.path.join(tmpdirname, "keras_model.h5")
                model.save(filepath)
                if "T5" in main_layer_class.__name__:
                    model = keras.models.load_model(
                        filepath,
                        custom_objects={
                            main_layer_class.__name__: main_layer_class,
                            "TFSharedEmbeddings": TFSharedEmbeddings,
                        },
                    )
                else:
                    model = keras.models.load_model(
                        filepath, custom_objects={main_layer_class.__name__: main_layer_class}
                    )
                assert isinstance(model, keras.Model)
                after_outputs = model(inputs_dict)
                self.assert_outputs_same(after_outputs, outputs)

    @slow
    def test_model_from_pretrained(self):
        model_name = "openai/clip-vit-base-patch32"
        model = TFCLIPModel.from_pretrained(model_name)
        self.assertIsNotNone(model)

    @unittest.skip(reason="Currently `saved_model` doesn't work with nested outputs.")
    @slow
    def test_saved_model_creation(self):
        pass

    @unittest.skip(reason="Currently `saved_model` doesn't work with nested outputs.")
    @slow
    def test_saved_model_creation_extended(self):
        pass

    @unittest.skip(reason="`saved_model` doesn't work with nested outputs so no preparation happens.")
    @slow
    def test_prepare_serving_output(self):
        pass


# We will verify our results on an image of cute cats
def prepare_img():
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    im = Image.open(requests.get(url, stream=True).raw)
    return im


@require_vision
@require_tf
class TFCLIPModelIntegrationTest(unittest.TestCase):
    @slow
    def test_inference(self):
        model_name = "openai/clip-vit-base-patch32"
        model = TFCLIPModel.from_pretrained(model_name)
        processor = CLIPProcessor.from_pretrained(model_name)

        image = prepare_img()
        inputs = processor(
            text=["a photo of a cat", "a photo of a dog"], images=image, padding=True, return_tensors="tf"
        )

        outputs = model(**inputs, training=False)

        # verify the logits
        self.assertEqual(
            outputs.logits_per_image.shape,
            tf.TensorShape((inputs.pixel_values.shape[0], inputs.input_ids.shape[0])),
        )
        self.assertEqual(
            outputs.logits_per_text.shape,
            tf.TensorShape((inputs.input_ids.shape[0], inputs.pixel_values.shape[0])),
        )

        expected_logits = tf.constant([[24.5701, 19.3049]])

        tf.debugging.assert_near(outputs.logits_per_image, expected_logits, atol=1e-3)
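`test_keras_save_load` above builds a functional Keras model from symbolic inputs, round-trips it through an HDF5 file, and checks that the reloaded model reproduces the original outputs. The same round trip in isolation, with a toy layer standing in for the CLIP main layer (all names here are illustrative):

import os
import tempfile

import tensorflow as tf
from tensorflow import keras

# Toy stand-in for a *MainLayer: a functional model built from symbolic inputs.
symbolic_inputs = keras.Input(shape=(4,), dtype=tf.float32, name="features")
outputs = keras.layers.Dense(2, name="head")(symbolic_inputs)
model = keras.Model(symbolic_inputs, outputs)

x = tf.ones((1, 4))
before = model(x)

with tempfile.TemporaryDirectory() as tmpdirname:
    filepath = os.path.join(tmpdirname, "keras_model.h5")
    model.save(filepath)  # HDF5 serialization, as in the test above
    reloaded = keras.models.load_model(filepath)
    after = reloaded(x)

# The reloaded model must reproduce the original outputs.
tf.debugging.assert_near(before, after, atol=1e-6)

For custom layers, as in the CLIP test, `load_model` additionally needs a `custom_objects` mapping so Keras can resolve the serialized class names.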
docs/transformers/tests/models/clip/test_processor_clip.py
ADDED
@@ -0,0 +1,199 @@
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import shutil
import tempfile
import unittest

import pytest

from transformers import CLIPTokenizer, CLIPTokenizerFast
from transformers.models.clip.tokenization_clip import VOCAB_FILES_NAMES
from transformers.testing_utils import require_vision
from transformers.utils import IMAGE_PROCESSOR_NAME, is_vision_available

from ...test_processing_common import ProcessorTesterMixin


if is_vision_available():
    from transformers import CLIPImageProcessor, CLIPProcessor


@require_vision
class CLIPProcessorTest(ProcessorTesterMixin, unittest.TestCase):
    processor_class = CLIPProcessor

    @classmethod
    def setUpClass(cls):
        cls.tmpdirname = tempfile.mkdtemp()

        vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n", "lo", "l</w>", "w</w>", "r</w>", "t</w>", "low</w>", "er</w>", "lowest</w>", "newer</w>", "wider", "<unk>", "<|startoftext|>", "<|endoftext|>"]  # fmt: skip
        vocab_tokens = dict(zip(vocab, range(len(vocab))))
        merges = ["#version: 0.2", "l o", "lo w</w>", "e r</w>", ""]
        cls.special_tokens_map = {"unk_token": "<unk>"}

        cls.vocab_file = os.path.join(cls.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
        cls.merges_file = os.path.join(cls.tmpdirname, VOCAB_FILES_NAMES["merges_file"])
        with open(cls.vocab_file, "w", encoding="utf-8") as fp:
            fp.write(json.dumps(vocab_tokens) + "\n")
        with open(cls.merges_file, "w", encoding="utf-8") as fp:
            fp.write("\n".join(merges))

        image_processor_map = {
            "do_resize": True,
            "size": 20,
            "do_center_crop": True,
            "crop_size": 18,
            "do_normalize": True,
            "image_mean": [0.48145466, 0.4578275, 0.40821073],
            "image_std": [0.26862954, 0.26130258, 0.27577711],
        }
        cls.image_processor_file = os.path.join(cls.tmpdirname, IMAGE_PROCESSOR_NAME)
        with open(cls.image_processor_file, "w", encoding="utf-8") as fp:
            json.dump(image_processor_map, fp)

    @classmethod
    def get_tokenizer(cls, **kwargs):
        return CLIPTokenizer.from_pretrained(cls.tmpdirname, **kwargs)

    @classmethod
    def get_rust_tokenizer(cls, **kwargs):
        return CLIPTokenizerFast.from_pretrained(cls.tmpdirname, **kwargs)

    @classmethod
    def get_image_processor(cls, **kwargs):
        return CLIPImageProcessor.from_pretrained(cls.tmpdirname, **kwargs)

    @classmethod
    def tearDownClass(cls):
        shutil.rmtree(cls.tmpdirname)

    def test_save_load_pretrained_default(self):
        tokenizer_slow = self.get_tokenizer()
        tokenizer_fast = self.get_rust_tokenizer()
        image_processor = self.get_image_processor()

        with tempfile.TemporaryDirectory() as tmpdir:
            processor_slow = CLIPProcessor(tokenizer=tokenizer_slow, image_processor=image_processor)
            processor_slow.save_pretrained(tmpdir)
            processor_slow = CLIPProcessor.from_pretrained(tmpdir, use_fast=False)

            processor_fast = CLIPProcessor(tokenizer=tokenizer_fast, image_processor=image_processor)
            processor_fast.save_pretrained(tmpdir)
            processor_fast = CLIPProcessor.from_pretrained(tmpdir)

        self.assertEqual(processor_slow.tokenizer.get_vocab(), tokenizer_slow.get_vocab())
        self.assertEqual(processor_fast.tokenizer.get_vocab(), tokenizer_fast.get_vocab())
        self.assertEqual(tokenizer_slow.get_vocab(), tokenizer_fast.get_vocab())
        self.assertIsInstance(processor_slow.tokenizer, CLIPTokenizer)
        self.assertIsInstance(processor_fast.tokenizer, CLIPTokenizerFast)

        self.assertEqual(processor_slow.image_processor.to_json_string(), image_processor.to_json_string())
        self.assertEqual(processor_fast.image_processor.to_json_string(), image_processor.to_json_string())
        self.assertIsInstance(processor_slow.image_processor, CLIPImageProcessor)
        self.assertIsInstance(processor_fast.image_processor, CLIPImageProcessor)

    def test_save_load_pretrained_additional_features(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            processor = CLIPProcessor(tokenizer=self.get_tokenizer(), image_processor=self.get_image_processor())
            processor.save_pretrained(tmpdir)

            tokenizer_add_kwargs = CLIPTokenizer.from_pretrained(tmpdir, bos_token="(BOS)", eos_token="(EOS)")
            image_processor_add_kwargs = CLIPImageProcessor.from_pretrained(
                tmpdir, do_normalize=False, padding_value=1.0
            )

            processor = CLIPProcessor.from_pretrained(
                tmpdir, bos_token="(BOS)", eos_token="(EOS)", do_normalize=False, padding_value=1.0
            )

        self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab())
        self.assertIsInstance(processor.tokenizer, CLIPTokenizerFast)

        self.assertEqual(processor.image_processor.to_json_string(), image_processor_add_kwargs.to_json_string())
        self.assertIsInstance(processor.image_processor, CLIPImageProcessor)

    def test_image_processor(self):
        image_processor = self.get_image_processor()
        tokenizer = self.get_tokenizer()

        processor = CLIPProcessor(tokenizer=tokenizer, image_processor=image_processor)

        image_input = self.prepare_image_inputs()

        input_image_proc = image_processor(image_input, return_tensors="np")
        input_processor = processor(images=image_input, return_tensors="np")

        for key in input_image_proc.keys():
            self.assertAlmostEqual(input_image_proc[key].sum(), input_processor[key].sum(), delta=1e-2)

    def test_tokenizer(self):
        image_processor = self.get_image_processor()
        tokenizer = self.get_tokenizer()

        processor = CLIPProcessor(tokenizer=tokenizer, image_processor=image_processor)

        input_str = "lower newer"

        encoded_processor = processor(text=input_str)

        encoded_tok = tokenizer(input_str)

        for key in encoded_tok.keys():
            self.assertListEqual(encoded_tok[key], encoded_processor[key])

    def test_processor(self):
        image_processor = self.get_image_processor()
        tokenizer = self.get_tokenizer()

        processor = CLIPProcessor(tokenizer=tokenizer, image_processor=image_processor)

        input_str = "lower newer"
        image_input = self.prepare_image_inputs()

        inputs = processor(text=input_str, images=image_input)

        self.assertListEqual(list(inputs.keys()), ["input_ids", "attention_mask", "pixel_values"])

        # test if it raises when no input is passed
        with pytest.raises(ValueError):
            processor()

    def test_tokenizer_decode(self):
        image_processor = self.get_image_processor()
        tokenizer = self.get_tokenizer()

        processor = CLIPProcessor(tokenizer=tokenizer, image_processor=image_processor)

        predicted_ids = [[1, 4, 5, 8, 1, 0, 8], [3, 4, 3, 1, 1, 8, 9]]

        decoded_processor = processor.batch_decode(predicted_ids)
        decoded_tok = tokenizer.batch_decode(predicted_ids)

        self.assertListEqual(decoded_tok, decoded_processor)

    def test_model_input_names(self):
        image_processor = self.get_image_processor()
        tokenizer = self.get_tokenizer()

        processor = CLIPProcessor(tokenizer=tokenizer, image_processor=image_processor)

        input_str = "lower newer"
        image_input = self.prepare_image_inputs()

        inputs = processor(text=input_str, images=image_input)

        self.assertListEqual(list(inputs.keys()), processor.model_input_names)
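A minimal usage sketch of the behavior these tests assert: CLIPProcessor routes text through the tokenizer and images through the image processor, then merges both encodings into one dict whose keys match model_input_names. The sketch assumes the public openai/clip-vit-base-patch32 checkpoint rather than the toy vocabulary built in setUpClass, and the random image is purely illustrative.

import numpy as np
from PIL import Image

from transformers import CLIPProcessor

processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
# Any RGB image works; a random one keeps the sketch self-contained.
image = Image.fromarray(np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8))

inputs = processor(text="lower newer", images=image, return_tensors="np")
print(sorted(inputs.keys()))  # ['attention_mask', 'input_ids', 'pixel_values']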
docs/transformers/tests/models/clip/test_tokenization_clip.py
ADDED
@@ -0,0 +1,192 @@
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import json
import os
import unittest
from functools import lru_cache

from transformers import CLIPTokenizer, CLIPTokenizerFast
from transformers.models.clip.tokenization_clip import VOCAB_FILES_NAMES
from transformers.testing_utils import require_ftfy, require_tokenizers

from ...test_tokenization_common import TokenizerTesterMixin, use_cache_if_possible


@require_tokenizers
class CLIPTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
    from_pretrained_id = "openai/clip-vit-base-patch32"
    tokenizer_class = CLIPTokenizer
    rust_tokenizer_class = CLIPTokenizerFast
    test_rust_tokenizer = True
    from_pretrained_kwargs = {}
    test_seq2seq = False

    @classmethod
    def setUpClass(cls):
        super().setUpClass()

        vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n", "lo", "l</w>", "w</w>", "r</w>", "t</w>", "low</w>", "er</w>", "lowest</w>", "newer</w>", "wider", "<unk>", "<|startoftext|>", "<|endoftext|>"]  # fmt: skip
        vocab_tokens = dict(zip(vocab, range(len(vocab))))
        merges = ["#version: 0.2", "l o", "lo w</w>", "e r</w>"]
        cls.special_tokens_map = {"unk_token": "<unk>"}

        cls.vocab_file = os.path.join(cls.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
        cls.merges_file = os.path.join(cls.tmpdirname, VOCAB_FILES_NAMES["merges_file"])
        with open(cls.vocab_file, "w", encoding="utf-8") as fp:
            fp.write(json.dumps(vocab_tokens) + "\n")
        with open(cls.merges_file, "w", encoding="utf-8") as fp:
            fp.write("\n".join(merges))

    @classmethod
    @use_cache_if_possible
    @lru_cache(maxsize=64)
    def get_tokenizer(cls, pretrained_name=None, **kwargs):
        kwargs.update(cls.special_tokens_map)
        pretrained_name = pretrained_name or cls.tmpdirname
        return CLIPTokenizer.from_pretrained(pretrained_name, **kwargs)

    @classmethod
    @use_cache_if_possible
    @lru_cache(maxsize=64)
    def get_rust_tokenizer(cls, pretrained_name=None, **kwargs):
        kwargs.update(cls.special_tokens_map)
        pretrained_name = pretrained_name or cls.tmpdirname
        return CLIPTokenizerFast.from_pretrained(pretrained_name, **kwargs)

    def get_input_output_texts(self, tokenizer):
        input_text = "lower newer"
        output_text = "lower newer"
        return input_text, output_text

    def test_full_tokenizer(self):
        tokenizer = CLIPTokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map)
        text = "lower newer"
        bpe_tokens = ["lo", "w", "er</w>", "n", "e", "w", "er</w>"]
        tokens = tokenizer.tokenize(text)
        self.assertListEqual(tokens, bpe_tokens)

        input_tokens = tokens + [tokenizer.unk_token]
        input_bpe_tokens = [10, 2, 16, 9, 3, 2, 16, 20]
        self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)

    @require_ftfy
    def test_check_encoding_slow_fast(self):
        for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
            with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
                tokenizer_s = self.get_tokenizer(pretrained_name, **kwargs)
                tokenizer_r = self.get_rust_tokenizer(pretrained_name, **kwargs)

                text = "A\n'll 11p223RF☆ho!!to?'d'd''d of a cat to-$''d."
                text_tokenized_s = tokenizer_s.tokenize(text)
                text_tokenized_r = tokenizer_r.tokenize(text)

                self.assertListEqual(text_tokenized_s, text_tokenized_r)

                # Test that the tokenization is identical on an example containing a character (Latin Small Letter A
                # with Tilde) encoded in 2 different ways
                text = "xa\u0303y" + " " + "x\xe3y"
                text_tokenized_s = tokenizer_s.tokenize(text)
                text_tokenized_r = tokenizer_r.tokenize(text)

                self.assertListEqual(text_tokenized_s, text_tokenized_r)

                # Test that the tokenization is identical on unicode of space type
                spaces_unicodes = [
                    "\u0009",  # (horizontal tab, '\t')
                    "\u000b",  # (vertical tab)
                    "\u000c",  # (form feed)
                    "\u0020",  # (space, ' ')
                    "\u200e",  # (left-to-right mark)
                    "\u200f",  # (right-to-left mark)
                ]
                for unicode_seq in spaces_unicodes:
                    text_tokenized_s = tokenizer_s.tokenize(unicode_seq)
                    text_tokenized_r = tokenizer_r.tokenize(unicode_seq)

                    self.assertListEqual(text_tokenized_s, text_tokenized_r)

                # Test that the tokenization is identical on unicode of line break type
                line_break_unicodes = [
                    "\u000a",  # (line feed, '\n')
                    "\r\n",  # (carriage return and line feed, '\r\n')
                    "\u000d",  # (carriage return, '\r')
                    "\r",  # (carriage return, '\r')
                    "\u000d",  # (carriage return, '\r')
                    "\u2028",  # (line separator)
                    "\u2029",  # (paragraph separator)
                    # "\u0085",  # (next line)
                ]

                # The tokenization is not identical for the character "\u0085" (next line). The slow version using ftfy transforms
                # it into the Horizontal Ellipsis character "…" ("\u2026") while the fast version transforms it into a
                # space (and thus into an empty list).

                for unicode_seq in line_break_unicodes:
                    text_tokenized_s = tokenizer_s.tokenize(unicode_seq)
                    text_tokenized_r = tokenizer_r.tokenize(unicode_seq)

                    self.assertListEqual(text_tokenized_s, text_tokenized_r)

    def test_offsets_mapping_with_different_add_prefix_space_argument(self):
        # Test which aims to verify that the offsets are well adapted to the argument `add_prefix_space`
        for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
            with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
                text_of_1_token = "hello"  # `hello` is a token in the vocabulary of `pretrained_name`
                text = f"{text_of_1_token} {text_of_1_token}"

                tokenizer_r = self.get_rust_tokenizer(
                    pretrained_name,
                    use_fast=True,
                )
                encoding = tokenizer_r(text, return_offsets_mapping=True, add_special_tokens=False)
                self.assertEqual(encoding.offset_mapping[0], (0, len(text_of_1_token)))
                self.assertEqual(
                    encoding.offset_mapping[1],
                    (len(text_of_1_token) + 1, len(text_of_1_token) + 1 + len(text_of_1_token)),
                )

                text = f" {text}"

                tokenizer_r = self.get_rust_tokenizer(
                    pretrained_name,
                    use_fast=True,
                )
                encoding = tokenizer_r(text, return_offsets_mapping=True, add_special_tokens=False)
                self.assertEqual(encoding.offset_mapping[0], (1, 1 + len(text_of_1_token)))
                self.assertEqual(
                    encoding.offset_mapping[1],
                    (1 + len(text_of_1_token) + 1, 1 + len(text_of_1_token) + 1 + len(text_of_1_token)),
                )

    def test_log_warning(self):
        # Test related to the breaking change introduced in transformers v4.17.0
        # We need to check that an error is raised when the user tries to load a previous version of the tokenizer.
        with self.assertRaises(ValueError) as context:
            self.get_rust_tokenizer("robot-test/old-clip-tokenizer")

        self.assertTrue(
            context.exception.args[0].startswith(
                "The `backend_tokenizer` provided does not match the expected format."
            )
        )

    @require_ftfy
    def test_tokenization_python_rust_equals(self):
        super().test_tokenization_python_rust_equals()

    @unittest.skip(reason="CLIP always lower cases letters")
    def test_added_tokens_do_lower_case(self):
        pass
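The toy vocabulary and merges above are what make test_full_tokenizer deterministic. A minimal standalone sketch of the same BPE split (assuming vocab.json and merges.txt are the filenames CLIP's VOCAB_FILES_NAMES maps to, which is how the test writes them):

import json
import os
import tempfile

from transformers import CLIPTokenizer

vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n", "lo", "l</w>", "w</w>", "r</w>", "t</w>", "low</w>", "er</w>", "lowest</w>", "newer</w>", "wider", "<unk>", "<|startoftext|>", "<|endoftext|>"]
merges = ["#version: 0.2", "l o", "lo w</w>", "e r</w>"]

with tempfile.TemporaryDirectory() as tmpdir:
    vocab_file = os.path.join(tmpdir, "vocab.json")
    merges_file = os.path.join(tmpdir, "merges.txt")
    with open(vocab_file, "w", encoding="utf-8") as fp:
        json.dump(dict(zip(vocab, range(len(vocab)))), fp)
    with open(merges_file, "w", encoding="utf-8") as fp:
        fp.write("\n".join(merges))

    tokenizer = CLIPTokenizer(vocab_file, merges_file, unk_token="<unk>")
    # "lower" reduces via the "l o" and "e r</w>" merges to "lo" + "w" + "er</w>";
    # "newer</w>" is in the vocab but has no merge path, so it splits per character
    # except for the final "er</w>" merge.
    print(tokenizer.tokenize("lower newer"))  # ['lo', 'w', 'er</w>', 'n', 'e', 'w', 'er</w>']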
docs/transformers/tests/models/clipseg/__init__.py
ADDED
File without changes
docs/transformers/tests/models/clipseg/test_modeling_clipseg.py
ADDED
@@ -0,0 +1,714 @@
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch CLIPSeg model."""

import inspect
import os
import tempfile
import unittest

import numpy as np
import requests

from transformers import CLIPSegConfig, CLIPSegProcessor, CLIPSegTextConfig, CLIPSegVisionConfig
from transformers.testing_utils import (
    require_torch,
    require_vision,
    slow,
    torch_device,
)
from transformers.utils import is_torch_available, is_vision_available

from ...test_configuration_common import ConfigTester
from ...test_modeling_common import (
    ModelTesterMixin,
    _config_zero_init,
    floats_tensor,
    ids_tensor,
    random_attention_mask,
)
from ...test_pipeline_mixin import PipelineTesterMixin


if is_torch_available():
    import torch
    from torch import nn

    from transformers import CLIPSegForImageSegmentation, CLIPSegModel, CLIPSegTextModel, CLIPSegVisionModel
    from transformers.models.auto.modeling_auto import MODEL_MAPPING_NAMES


if is_vision_available():
    from PIL import Image


class CLIPSegVisionModelTester:
    def __init__(
        self,
        parent,
        batch_size=12,
        image_size=30,
        patch_size=2,
        num_channels=3,
        is_training=True,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=37,
        dropout=0.1,
        attention_dropout=0.1,
        initializer_range=0.02,
        scope=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.is_training = is_training
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.initializer_range = initializer_range
        self.scope = scope

        # in ViT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token)
        num_patches = (image_size // patch_size) ** 2
        self.seq_length = num_patches + 1

    def prepare_config_and_inputs(self):
        pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
        config = self.get_config()

        return config, pixel_values

    def get_config(self):
        return CLIPSegVisionConfig(
            image_size=self.image_size,
            patch_size=self.patch_size,
            num_channels=self.num_channels,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            dropout=self.dropout,
            attention_dropout=self.attention_dropout,
            initializer_range=self.initializer_range,
        )

    def create_and_check_model(self, config, pixel_values):
        model = CLIPSegVisionModel(config=config)
        model.to(torch_device)
        model.eval()
        with torch.no_grad():
            result = model(pixel_values)
        # expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
        image_size = (self.image_size, self.image_size)
        patch_size = (self.patch_size, self.patch_size)
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
        self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, pixel_values = config_and_inputs
        inputs_dict = {"pixel_values": pixel_values}
        return config, inputs_dict


@require_torch
class CLIPSegVisionModelTest(ModelTesterMixin, unittest.TestCase):
    """
    Here we also overwrite some of the tests of test_modeling_common.py, as CLIPSeg does not use input_ids, inputs_embeds,
    attention_mask and seq_length.
    """

    all_model_classes = (CLIPSegVisionModel,) if is_torch_available() else ()
    fx_compatible = False
    test_pruning = False
    test_resize_embeddings = False
    test_head_masking = False

    def setUp(self):
        self.model_tester = CLIPSegVisionModelTester(self)
        self.config_tester = ConfigTester(
            self, config_class=CLIPSegVisionConfig, has_text_modality=False, hidden_size=37
        )

    def test_config(self):
        self.config_tester.run_common_tests()

    @unittest.skip(reason="CLIPSeg does not use inputs_embeds")
    def test_inputs_embeds(self):
        pass

    def test_model_get_set_embeddings(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            self.assertIsInstance(model.get_input_embeddings(), (nn.Module))
            x = model.get_output_embeddings()
            self.assertTrue(x is None or isinstance(x, nn.Linear))

    def test_forward_signature(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.forward)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            expected_arg_names = ["pixel_values"]
            self.assertListEqual(arg_names[:1], expected_arg_names)

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    @unittest.skip
    def test_training(self):
        pass

    @unittest.skip
    def test_training_gradient_checkpointing(self):
        pass

    @unittest.skip(
        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant(self):
        pass

    @unittest.skip(
        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass

    @slow
    def test_model_from_pretrained(self):
        model_name = "CIDAS/clipseg-rd64-refined"
        model = CLIPSegVisionModel.from_pretrained(model_name)
        self.assertIsNotNone(model)


class CLIPSegTextModelTester:
    def __init__(
        self,
        parent,
        batch_size=12,
        seq_length=7,
        is_training=True,
        use_input_mask=True,
        use_labels=True,
        vocab_size=99,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=37,
        dropout=0.1,
        attention_dropout=0.1,
        max_position_embeddings=512,
        initializer_range=0.02,
        scope=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_input_mask = use_input_mask
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.scope = scope

    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
            input_mask = random_attention_mask([self.batch_size, self.seq_length])

        if input_mask is not None:
            batch_size, seq_length = input_mask.shape
            rnd_start_indices = np.random.randint(1, seq_length - 1, size=(batch_size,))
            for batch_idx, start_index in enumerate(rnd_start_indices):
                input_mask[batch_idx, :start_index] = 1
                input_mask[batch_idx, start_index:] = 0

        config = self.get_config()

        return config, input_ids, input_mask

    def get_config(self):
        return CLIPSegTextConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            dropout=self.dropout,
            attention_dropout=self.attention_dropout,
            max_position_embeddings=self.max_position_embeddings,
            initializer_range=self.initializer_range,
        )

    def create_and_check_model(self, config, input_ids, input_mask):
        model = CLIPSegTextModel(config=config)
        model.to(torch_device)
        model.eval()
        with torch.no_grad():
            result = model(input_ids, attention_mask=input_mask)
            result = model(input_ids)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
        self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, input_ids, input_mask = config_and_inputs
        inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
        return config, inputs_dict


@require_torch
class CLIPSegTextModelTest(ModelTesterMixin, unittest.TestCase):
    all_model_classes = (CLIPSegTextModel,) if is_torch_available() else ()
    fx_compatible = False
    test_pruning = False
    test_head_masking = False
    model_split_percents = [0.5, 0.8, 0.9]

    def setUp(self):
        self.model_tester = CLIPSegTextModelTester(self)
        self.config_tester = ConfigTester(self, config_class=CLIPSegTextConfig, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    @unittest.skip
    def test_training(self):
        pass

    @unittest.skip
    def test_training_gradient_checkpointing(self):
        pass

    @unittest.skip(
        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant(self):
        pass

    @unittest.skip(
        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass

    @unittest.skip(reason="CLIPSeg does not use inputs_embeds")
    def test_inputs_embeds(self):
        pass

    @slow
    def test_model_from_pretrained(self):
        model_name = "CIDAS/clipseg-rd64-refined"
        model = CLIPSegTextModel.from_pretrained(model_name)
        self.assertIsNotNone(model)


class CLIPSegModelTester:
    def __init__(
        self,
        parent,
        text_kwargs=None,
        vision_kwargs=None,
        is_training=True,
        # This should respect the `num_hidden_layers` in `CLIPSegVisionModelTester`
        extract_layers=(1,),
    ):
        if text_kwargs is None:
            text_kwargs = {}
        if vision_kwargs is None:
            vision_kwargs = {}

        self.parent = parent
        self.text_model_tester = CLIPSegTextModelTester(parent, **text_kwargs)
        self.vision_model_tester = CLIPSegVisionModelTester(parent, **vision_kwargs)
        self.batch_size = self.text_model_tester.batch_size  # need bs for batching_equivalence test
        self.is_training = is_training
        self.extract_layers = extract_layers

    def prepare_config_and_inputs(self):
        text_config, input_ids, attention_mask = self.text_model_tester.prepare_config_and_inputs()
        vision_config, pixel_values = self.vision_model_tester.prepare_config_and_inputs()

        config = self.get_config()

        return config, input_ids, attention_mask, pixel_values

    def get_config(self):
        return CLIPSegConfig.from_text_vision_configs(
            self.text_model_tester.get_config(),
            self.vision_model_tester.get_config(),
            projection_dim=64,
            reduce_dim=32,
            extract_layers=self.extract_layers,
        )

    def create_and_check_model(self, config, input_ids, attention_mask, pixel_values):
        model = CLIPSegModel(config).to(torch_device).eval()
        with torch.no_grad():
            result = model(input_ids, pixel_values, attention_mask)
        self.parent.assertEqual(
            result.logits_per_image.shape, (self.vision_model_tester.batch_size, self.text_model_tester.batch_size)
        )
        self.parent.assertEqual(
            result.logits_per_text.shape, (self.text_model_tester.batch_size, self.vision_model_tester.batch_size)
        )

    def create_and_check_model_for_image_segmentation(self, config, input_ids, attention_mask, pixel_values):
        model = CLIPSegForImageSegmentation(config).to(torch_device).eval()
        with torch.no_grad():
            result = model(input_ids, pixel_values)
        self.parent.assertEqual(
            result.logits.shape,
            (
                self.vision_model_tester.batch_size,
                self.vision_model_tester.image_size,
                self.vision_model_tester.image_size,
            ),
        )
        self.parent.assertEqual(
            result.conditional_embeddings.shape, (self.text_model_tester.batch_size, config.projection_dim)
        )

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, input_ids, attention_mask, pixel_values = config_and_inputs
        inputs_dict = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "pixel_values": pixel_values,
        }
        return config, inputs_dict


@require_torch
class CLIPSegModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    all_model_classes = (CLIPSegModel, CLIPSegForImageSegmentation) if is_torch_available() else ()
    pipeline_model_mapping = {"feature-extraction": CLIPSegModel} if is_torch_available() else {}
    fx_compatible = False
    test_head_masking = False
    test_pruning = False
    test_resize_embeddings = False
    test_attention_outputs = False

    def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
        # CLIPSegForImageSegmentation requires special treatment
        if return_labels:
            if model_class.__name__ == "CLIPSegForImageSegmentation":
                batch_size, _, height, width = inputs_dict["pixel_values"].shape
                inputs_dict["labels"] = torch.zeros(
                    [batch_size, height, width], device=torch_device, dtype=torch.float
                )

        return inputs_dict

    def setUp(self):
        self.model_tester = CLIPSegModelTester(self)
        common_properties = ["projection_dim", "logit_scale_init_value"]
        self.config_tester = ConfigTester(
            self, config_class=CLIPSegConfig, has_text_modality=False, common_properties=common_properties
        )

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_model_for_image_segmentation(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model_for_image_segmentation(*config_and_inputs)

    @unittest.skip(reason="Hidden_states is tested in individual model tests")
    def test_hidden_states_output(self):
        pass

    @unittest.skip(reason="Inputs_embeds is tested in individual model tests")
    def test_inputs_embeds(self):
        pass

    @unittest.skip(reason="Retain_grad is tested in individual model tests")
    def test_retain_grad_hidden_states_attentions(self):
        pass

    @unittest.skip(reason="CLIPSegModel does not have input/output embeddings")
    def test_model_get_set_embeddings(self):
        pass

    @unittest.skip(
        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing(self):
        pass

    @unittest.skip(
        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant(self):
        pass

    @unittest.skip(
        reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass

    # override as some parameters require custom initialization
    def test_initialization(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        configs_no_init = _config_zero_init(config)
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            for name, param in model.named_parameters():
                if param.requires_grad:
                    # check if `logit_scale` is initialized as per the original implementation
                    if "logit_scale" in name:
                        self.assertAlmostEqual(
                            param.data.item(),
                            np.log(1 / 0.07),
                            delta=1e-3,
                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                        )
                    elif "film" in name or "transposed_conv" in name or "reduce" in name:
                        # those parameters use PyTorch's default nn.Linear initialization scheme
                        pass
                    else:
                        self.assertIn(
                            ((param.data.mean() * 1e9).round() / 1e9).item(),
                            [0.0, 1.0],
                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                        )

    def _create_and_check_torchscript(self, config, inputs_dict):
        if not self.test_torchscript:
            self.skipTest(reason="test_torchscript is set to False")

        configs_no_init = _config_zero_init(config)  # To be sure we have no Nan
        configs_no_init.torchscript = True
        configs_no_init.return_dict = False
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            model.to(torch_device)
            model.eval()

            try:
                input_ids = inputs_dict["input_ids"]
                pixel_values = inputs_dict["pixel_values"]  # CLIPSeg needs pixel_values
                traced_model = torch.jit.trace(model, (input_ids, pixel_values))
            except RuntimeError:
                self.fail("Couldn't trace module.")

            with tempfile.TemporaryDirectory() as tmp_dir_name:
                pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")

                try:
                    torch.jit.save(traced_model, pt_file_name)
                except Exception:
                    self.fail("Couldn't save module.")

                try:
                    loaded_model = torch.jit.load(pt_file_name)
                except Exception:
                    self.fail("Couldn't load module.")

            model.to(torch_device)
            model.eval()

            loaded_model.to(torch_device)
            loaded_model.eval()

            model_state_dict = model.state_dict()
            loaded_model_state_dict = loaded_model.state_dict()

            non_persistent_buffers = {}
            for key in loaded_model_state_dict.keys():
                if key not in model_state_dict.keys():
                    non_persistent_buffers[key] = loaded_model_state_dict[key]

            loaded_model_state_dict = {
                key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers
            }

            self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))

            model_buffers = list(model.buffers())
            for non_persistent_buffer in non_persistent_buffers.values():
                found_buffer = False
                for i, model_buffer in enumerate(model_buffers):
                    if torch.equal(non_persistent_buffer, model_buffer):
                        found_buffer = True
                        break

                self.assertTrue(found_buffer)
                model_buffers.pop(i)

            models_equal = True
            for layer_name, p1 in model_state_dict.items():
                p2 = loaded_model_state_dict[layer_name]
                if p1.data.ne(p2.data).sum() > 0:
                    models_equal = False

            self.assertTrue(models_equal)

    def test_load_vision_text_config(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        # Save CLIPSegConfig and check if we can load CLIPSegVisionConfig from it
        with tempfile.TemporaryDirectory() as tmp_dir_name:
            config.save_pretrained(tmp_dir_name)
            vision_config = CLIPSegVisionConfig.from_pretrained(tmp_dir_name)
            self.assertDictEqual(config.vision_config.to_dict(), vision_config.to_dict())

        # Save CLIPSegConfig and check if we can load CLIPSegTextConfig from it
        with tempfile.TemporaryDirectory() as tmp_dir_name:
            config.save_pretrained(tmp_dir_name)
            text_config = CLIPSegTextConfig.from_pretrained(tmp_dir_name)
            self.assertDictEqual(config.text_config.to_dict(), text_config.to_dict())

    def test_training(self):
        if not self.model_tester.is_training:
            self.skipTest(reason="Training test is skipped as the model was not trained")

        for model_class in self.all_model_classes:
            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
            config.return_dict = True

            if model_class.__name__ in MODEL_MAPPING_NAMES.values():
                continue

            print("Model class:", model_class)

            model = model_class(config)
            model.to(torch_device)
            model.train()
            inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
            for k, v in inputs.items():
                print(k, v.shape)
            loss = model(**inputs).loss
            loss.backward()

    @slow
    def test_model_from_pretrained(self):
        model_name = "CIDAS/clipseg-rd64-refined"
        model = CLIPSegModel.from_pretrained(model_name)
        self.assertIsNotNone(model)


# We will verify our results on an image of cute cats
def prepare_img():
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    image = Image.open(requests.get(url, stream=True).raw)
    return image


@require_vision
@require_torch
class CLIPSegModelIntegrationTest(unittest.TestCase):
    @slow
    def test_inference_image_segmentation(self):
        model_name = "CIDAS/clipseg-rd64-refined"
        processor = CLIPSegProcessor.from_pretrained(model_name)
        model = CLIPSegForImageSegmentation.from_pretrained(model_name).to(torch_device)

        image = prepare_img()
        texts = ["a cat", "a remote", "a blanket"]
        inputs = processor(text=texts, images=[image] * len(texts), padding=True, return_tensors="pt").to(torch_device)

        # forward pass
        with torch.no_grad():
            outputs = model(**inputs)

        # verify the predicted masks
        self.assertEqual(
            outputs.logits.shape,
            torch.Size((3, 352, 352)),
        )
        expected_masks_slice = torch.tensor(
            [[-7.4613, -7.4785, -7.3628], [-7.3268, -7.0899, -7.1333], [-6.9838, -6.7900, -6.8913]]
        ).to(torch_device)

        torch.testing.assert_close(outputs.logits[0, :3, :3], expected_masks_slice, rtol=1e-3, atol=1e-3)

        # verify conditional and pooled output
        expected_conditional = torch.tensor([0.5601, -0.0314, 0.1980]).to(torch_device)
        expected_pooled_output = torch.tensor([0.5036, -0.2681, -0.2644]).to(torch_device)
        torch.testing.assert_close(outputs.conditional_embeddings[0, :3], expected_conditional, rtol=1e-3, atol=1e-3)
        torch.testing.assert_close(outputs.pooled_output[0, :3], expected_pooled_output, rtol=1e-3, atol=1e-3)

    @slow
    def test_inference_interpolate_pos_encoding(self):
        # ViT models have an `interpolate_pos_encoding` argument in their forward method,
        # allowing to interpolate the pre-trained position embeddings in order to use
        # the model on higher resolutions. The DINO model by Facebook AI leverages this
        # to visualize self-attention on higher resolution images.
        model = CLIPSegModel.from_pretrained("openai/clip-vit-base-patch32").to(torch_device)

        processor = CLIPSegProcessor.from_pretrained(
            "openai/clip-vit-base-patch32", size={"height": 180, "width": 180}, crop_size={"height": 180, "width": 180}
        )

        image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
        inputs = processor(text="what's in the image", images=image, return_tensors="pt").to(torch_device)

        # interpolate_pos_encoding=False should raise a ValueError
        with self.assertRaises(ValueError, msg="doesn't match model"):
            with torch.no_grad():
                model(**inputs, interpolate_pos_encoding=False)

        # forward pass
        with torch.no_grad():
            outputs = model(**inputs, interpolate_pos_encoding=True)

        # verify the logits
        expected_shape = torch.Size((1, 26, 768))

        self.assertEqual(outputs.vision_model_output.last_hidden_state.shape, expected_shape)

        expected_slice = torch.tensor(
            [[-0.1538, 0.0322, -0.3235], [0.2893, 0.1135, -0.5708], [0.0461, 0.1540, -0.6018]]
        ).to(torch_device)

        torch.testing.assert_close(
            outputs.vision_model_output.last_hidden_state[0, :3, :3], expected_slice, rtol=1e-4, atol=1e-4
        )
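A minimal inference sketch of what test_inference_image_segmentation checks (assuming network access and the same CIDAS/clipseg-rd64-refined checkpoint): the model emits one coarse 352x352 logit map per text prompt, and a sigmoid converts each map into per-pixel probabilities.

import requests
import torch
from PIL import Image

from transformers import CLIPSegForImageSegmentation, CLIPSegProcessor

model_name = "CIDAS/clipseg-rd64-refined"
processor = CLIPSegProcessor.from_pretrained(model_name)
model = CLIPSegForImageSegmentation.from_pretrained(model_name)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
texts = ["a cat", "a remote", "a blanket"]

# One (image, prompt) pair per prompt, as in the integration test above.
inputs = processor(text=texts, images=[image] * len(texts), padding=True, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

masks = torch.sigmoid(outputs.logits)  # (3, 352, 352) per-pixel probabilities
print(masks.shape)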
docs/transformers/tests/models/clipseg/test_processor_clipseg.py
ADDED
@@ -0,0 +1,194 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import shutil
import tempfile
import unittest

import pytest

from transformers import CLIPTokenizer, CLIPTokenizerFast
from transformers.models.clip.tokenization_clip import VOCAB_FILES_NAMES
from transformers.testing_utils import require_vision
from transformers.utils import IMAGE_PROCESSOR_NAME, is_vision_available

from ...test_processing_common import ProcessorTesterMixin


if is_vision_available():
    from transformers import CLIPSegProcessor, ViTImageProcessor


@require_vision
class CLIPSegProcessorTest(ProcessorTesterMixin, unittest.TestCase):
    processor_class = CLIPSegProcessor

    def setUp(self):
        self.tmpdirname = tempfile.mkdtemp()

        vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n", "lo", "l</w>", "w</w>", "r</w>", "t</w>", "low</w>", "er</w>", "lowest</w>", "newer</w>", "wider", "<unk>", "<|startoftext|>", "<|endoftext|>"]  # fmt: skip
        vocab_tokens = dict(zip(vocab, range(len(vocab))))
        merges = ["#version: 0.2", "l o", "lo w</w>", "e r</w>", ""]
        self.special_tokens_map = {"unk_token": "<unk>"}

        self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
        self.merges_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["merges_file"])
        with open(self.vocab_file, "w", encoding="utf-8") as fp:
            fp.write(json.dumps(vocab_tokens) + "\n")
        with open(self.merges_file, "w", encoding="utf-8") as fp:
            fp.write("\n".join(merges))

        image_processor_map = {
            "do_resize": True,
            "size": 20,
            "do_center_crop": True,
            "crop_size": 18,
            "do_normalize": True,
            "image_mean": [0.48145466, 0.4578275, 0.40821073],
            "image_std": [0.26862954, 0.26130258, 0.27577711],
        }
        self.image_processor_file = os.path.join(self.tmpdirname, IMAGE_PROCESSOR_NAME)
        with open(self.image_processor_file, "w", encoding="utf-8") as fp:
            json.dump(image_processor_map, fp)

    def get_tokenizer(self, **kwargs):
        return CLIPTokenizer.from_pretrained(self.tmpdirname, **kwargs)

    def get_rust_tokenizer(self, **kwargs):
        return CLIPTokenizerFast.from_pretrained(self.tmpdirname, **kwargs)

    def get_image_processor(self, **kwargs):
        return ViTImageProcessor.from_pretrained(self.tmpdirname, **kwargs)

    def tearDown(self):
        shutil.rmtree(self.tmpdirname)

    def test_save_load_pretrained_default(self):
        tokenizer_slow = self.get_tokenizer()
        tokenizer_fast = self.get_rust_tokenizer()
        image_processor = self.get_image_processor()

        processor_slow = CLIPSegProcessor(tokenizer=tokenizer_slow, image_processor=image_processor)
        processor_slow.save_pretrained(self.tmpdirname)
        processor_slow = CLIPSegProcessor.from_pretrained(self.tmpdirname, use_fast=False)

        processor_fast = CLIPSegProcessor(tokenizer=tokenizer_fast, image_processor=image_processor)
        processor_fast.save_pretrained(self.tmpdirname)
        processor_fast = CLIPSegProcessor.from_pretrained(self.tmpdirname)

        self.assertEqual(processor_slow.tokenizer.get_vocab(), tokenizer_slow.get_vocab())
        self.assertEqual(processor_fast.tokenizer.get_vocab(), tokenizer_fast.get_vocab())
        self.assertEqual(tokenizer_slow.get_vocab(), tokenizer_fast.get_vocab())
        self.assertIsInstance(processor_slow.tokenizer, CLIPTokenizer)
        self.assertIsInstance(processor_fast.tokenizer, CLIPTokenizerFast)

        self.assertEqual(processor_slow.image_processor.to_json_string(), image_processor.to_json_string())
        self.assertEqual(processor_fast.image_processor.to_json_string(), image_processor.to_json_string())
        self.assertIsInstance(processor_slow.image_processor, ViTImageProcessor)
        self.assertIsInstance(processor_fast.image_processor, ViTImageProcessor)

    def test_save_load_pretrained_additional_features(self):
        processor = CLIPSegProcessor(tokenizer=self.get_tokenizer(), image_processor=self.get_image_processor())
        processor.save_pretrained(self.tmpdirname)

        tokenizer_add_kwargs = self.get_tokenizer(bos_token="(BOS)", eos_token="(EOS)")
        image_processor_add_kwargs = self.get_image_processor(do_normalize=False, padding_value=1.0)

        processor = CLIPSegProcessor.from_pretrained(
            self.tmpdirname, bos_token="(BOS)", eos_token="(EOS)", do_normalize=False, padding_value=1.0
        )

        self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab())
        self.assertIsInstance(processor.tokenizer, CLIPTokenizerFast)

        self.assertEqual(processor.image_processor.to_json_string(), image_processor_add_kwargs.to_json_string())
        self.assertIsInstance(processor.image_processor, ViTImageProcessor)

    def test_image_processor(self):
        image_processor = self.get_image_processor()
        tokenizer = self.get_tokenizer()

        processor = CLIPSegProcessor(tokenizer=tokenizer, image_processor=image_processor)

        image_input = self.prepare_image_inputs()

        input_feat_extract = image_processor(image_input, return_tensors="np")
        input_processor = processor(images=image_input, return_tensors="np")

        for key in input_feat_extract.keys():
            self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2)

    def test_tokenizer(self):
        image_processor = self.get_image_processor()
        tokenizer = self.get_tokenizer()

        processor = CLIPSegProcessor(tokenizer=tokenizer, image_processor=image_processor)

        input_str = "lower newer"

        encoded_processor = processor(text=input_str)

        encoded_tok = tokenizer(input_str)

        for key in encoded_tok.keys():
            self.assertListEqual(encoded_tok[key], encoded_processor[key])

    def test_processor_text(self):
        image_processor = self.get_image_processor()
        tokenizer = self.get_tokenizer()

        processor = CLIPSegProcessor(tokenizer=tokenizer, image_processor=image_processor)

        input_str = "lower newer"
        image_input = self.prepare_image_inputs()

        inputs = processor(text=input_str, images=image_input)

        self.assertListEqual(list(inputs.keys()), ["input_ids", "attention_mask", "pixel_values"])

        # test if it raises when no input is passed
        with pytest.raises(ValueError):
            processor()

    def test_processor_visual_prompt(self):
        image_processor = self.get_image_processor()
        tokenizer = self.get_tokenizer()

        processor = CLIPSegProcessor(tokenizer=tokenizer, image_processor=image_processor)

        image_input = self.prepare_image_inputs()
        visual_prompt_input = self.prepare_image_inputs()

        inputs = processor(images=image_input, visual_prompt=visual_prompt_input)

        self.assertListEqual(list(inputs.keys()), ["pixel_values", "conditional_pixel_values"])

        # test if it raises when no input is passed
        with pytest.raises(ValueError):
            processor()

    def test_tokenizer_decode(self):
        image_processor = self.get_image_processor()
        tokenizer = self.get_tokenizer()

        processor = CLIPSegProcessor(tokenizer=tokenizer, image_processor=image_processor)

        predicted_ids = [[1, 4, 5, 8, 1, 0, 8], [3, 4, 3, 1, 1, 8, 9]]

        decoded_processor = processor.batch_decode(predicted_ids)
        decoded_tok = tokenizer.batch_decode(predicted_ids)

        self.assertListEqual(decoded_tok, decoded_processor)
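For reference, these tests pin down the processor's two calling conventions: text plus images (tokenizer and image-processor outputs merged into one dict) and images plus a visual prompt. A minimal usage sketch under the same conventions; the checkpoint name and image paths are illustrative assumptions:

from PIL import Image

from transformers import CLIPSegProcessor

processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")  # assumed checkpoint
image = Image.open("scene.png")  # hypothetical local file

# text-conditioned: yields input_ids, attention_mask and pixel_values
text_inputs = processor(text="a cat", images=image, return_tensors="pt")

# image-conditioned: yields pixel_values and conditional_pixel_values
prompt = Image.open("cat_crop.png")  # hypothetical local file
prompt_inputs = processor(images=image, visual_prompt=prompt, return_tensors="pt")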
docs/transformers/tests/models/clvp/__init__.py
ADDED
File without changes
docs/transformers/tests/models/clvp/test_feature_extraction_clvp.py
ADDED
@@ -0,0 +1,240 @@
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
import os
import random
import tempfile
import unittest

import numpy as np
from datasets import Audio, load_dataset

from transformers import ClvpFeatureExtractor
from transformers.testing_utils import (
    check_json_file_has_correct_format,
    cleanup,
    require_torch,
    slow,
    torch_device,
)
from transformers.utils.import_utils import is_torch_available

from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin


if is_torch_available():
    import torch

global_rng = random.Random()


# Copied from transformers.tests.models.whisper.test_feature_extraction_whisper.floats_list
def floats_list(shape, scale=1.0, rng=None, name=None):
    """Creates a random float32 tensor"""
    if rng is None:
        rng = global_rng

    values = []
    for batch_idx in range(shape[0]):
        values.append([])
        for _ in range(shape[1]):
            values[-1].append(rng.random() * scale)

    return values

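# Note: despite its docstring, floats_list returns nested Python lists (shape[0]
# rows of shape[1] floats drawn uniformly from [0, scale)), not a tensor; the
# tests below convert the rows to numpy arrays or torch tensors as needed.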

@require_torch
class ClvpFeatureExtractionTester:
    def __init__(
        self,
        parent,
        batch_size=7,
        min_seq_length=400,
        max_seq_length=2000,
        feature_size=10,
        hop_length=160,
        chunk_length=8,
        padding_value=0.0,
        sampling_rate=4_000,
        return_attention_mask=False,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.min_seq_length = min_seq_length
        self.max_seq_length = max_seq_length
        self.seq_length_diff = (self.max_seq_length - self.min_seq_length) // (self.batch_size - 1)
        self.padding_value = padding_value
        self.sampling_rate = sampling_rate
        self.return_attention_mask = return_attention_mask
        self.feature_size = feature_size
        self.chunk_length = chunk_length
        self.hop_length = hop_length

    def prepare_feat_extract_dict(self):
        return {
            "feature_size": self.feature_size,
            "hop_length": self.hop_length,
            "chunk_length": self.chunk_length,
            "padding_value": self.padding_value,
            "sampling_rate": self.sampling_rate,
            "return_attention_mask": self.return_attention_mask,
        }

    # Copied from transformers.tests.models.whisper.test_feature_extraction_whisper.WhisperFeatureExtractionTester.prepare_inputs_for_common
    def prepare_inputs_for_common(self, equal_length=False, numpify=False):
        def _flatten(list_of_lists):
            return list(itertools.chain(*list_of_lists))

        if equal_length:
            speech_inputs = [floats_list((self.max_seq_length, self.feature_size)) for _ in range(self.batch_size)]
        else:
            # make sure that inputs increase in size
            speech_inputs = [
                floats_list((x, self.feature_size))
                for x in range(self.min_seq_length, self.max_seq_length, self.seq_length_diff)
            ]
        if numpify:
            speech_inputs = [np.asarray(x) for x in speech_inputs]
        return speech_inputs


@require_torch
class ClvpFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
    feature_extraction_class = ClvpFeatureExtractor

    def setUp(self):
        self.feat_extract_tester = ClvpFeatureExtractionTester(self)

    def tearDown(self):
        super().tearDown()
        # clean up as much of the GPU memory occupied by PyTorch as possible
        cleanup(torch_device)

    # Copied from transformers.tests.models.whisper.test_feature_extraction_whisper.WhisperFeatureExtractionTest.test_feat_extract_from_and_save_pretrained
    def test_feat_extract_from_and_save_pretrained(self):
        feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict)

        with tempfile.TemporaryDirectory() as tmpdirname:
            saved_file = feat_extract_first.save_pretrained(tmpdirname)[0]
            check_json_file_has_correct_format(saved_file)
            feat_extract_second = self.feature_extraction_class.from_pretrained(tmpdirname)

        dict_first = feat_extract_first.to_dict()
        dict_second = feat_extract_second.to_dict()
        mel_1 = feat_extract_first.mel_filters
        mel_2 = feat_extract_second.mel_filters
        self.assertTrue(np.allclose(mel_1, mel_2))
        self.assertEqual(dict_first, dict_second)

    # Copied from transformers.tests.models.whisper.test_feature_extraction_whisper.WhisperFeatureExtractionTest.test_feat_extract_to_json_file
    def test_feat_extract_to_json_file(self):
        feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict)

        with tempfile.TemporaryDirectory() as tmpdirname:
            json_file_path = os.path.join(tmpdirname, "feat_extract.json")
            feat_extract_first.to_json_file(json_file_path)
            feat_extract_second = self.feature_extraction_class.from_json_file(json_file_path)

        dict_first = feat_extract_first.to_dict()
        dict_second = feat_extract_second.to_dict()
        mel_1 = feat_extract_first.mel_filters
        mel_2 = feat_extract_second.mel_filters
        self.assertTrue(np.allclose(mel_1, mel_2))
        self.assertEqual(dict_first, dict_second)

    def test_call(self):
        # Tests that all calls wrap to encode_plus and batch_encode_plus
        feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
        # create three inputs of length 800, 1000, and 1200
        speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
        np_speech_inputs = [np.asarray(speech_input) for speech_input in speech_inputs]

        # Test feature size
        input_features = feature_extractor(np_speech_inputs, padding="max_length", return_tensors="np").input_features
        self.assertTrue(input_features.ndim == 3)
        self.assertTrue(input_features.shape[-2] == feature_extractor.feature_size)

        # Test not batched input
        encoded_sequences_1 = feature_extractor(speech_inputs[0], return_tensors="np").input_features
        encoded_sequences_2 = feature_extractor(np_speech_inputs[0], return_tensors="np").input_features
        self.assertTrue(np.allclose(encoded_sequences_1, encoded_sequences_2, atol=1e-3))

        # Test batched
        encoded_sequences_1 = feature_extractor(speech_inputs, return_tensors="np").input_features
        encoded_sequences_2 = feature_extractor(np_speech_inputs, return_tensors="np").input_features
        for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
            self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))

        # Test 2-D numpy arrays are batched.
        speech_inputs = [floats_list((1, x))[0] for x in (800, 800, 800)]
        np_speech_inputs = np.asarray(speech_inputs)
        encoded_sequences_1 = feature_extractor(speech_inputs, return_tensors="np").input_features
        encoded_sequences_2 = feature_extractor(np_speech_inputs, return_tensors="np").input_features
        for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
            self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))

        # Test truncation required
        speech_inputs = [floats_list((1, x))[0] for x in range(200, (feature_extractor.n_samples + 500), 200)]
        np_speech_inputs = [np.asarray(speech_input) for speech_input in speech_inputs]

        speech_inputs_truncated = [x[: feature_extractor.n_samples] for x in speech_inputs]
        np_speech_inputs_truncated = [np.asarray(speech_input) for speech_input in speech_inputs_truncated]

        encoded_sequences_1 = feature_extractor(np_speech_inputs, return_tensors="np").input_features
        encoded_sequences_2 = feature_extractor(np_speech_inputs_truncated, return_tensors="np").input_features
        for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
            self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))

    # Copied from transformers.tests.models.whisper.test_feature_extraction_whisper.WhisperFeatureExtractionTest.test_double_precision_pad
    def test_double_precision_pad(self):
        import torch

        feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
        np_speech_inputs = np.random.rand(100, 32).astype(np.float64)
        py_speech_inputs = np_speech_inputs.tolist()

        for inputs in [py_speech_inputs, np_speech_inputs]:
            np_processed = feature_extractor.pad([{"input_features": inputs}], return_tensors="np")
            self.assertTrue(np_processed.input_features.dtype == np.float32)
            pt_processed = feature_extractor.pad([{"input_features": inputs}], return_tensors="pt")
            self.assertTrue(pt_processed.input_features.dtype == torch.float32)

    def _load_datasamples(self, num_samples):
        ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        ds = ds.cast_column("audio", Audio(sampling_rate=22050))
        # automatic decoding with librispeech
        speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]

        return [x["array"] for x in speech_samples], [x["sampling_rate"] for x in speech_samples]

    @slow
    def test_integration(self):
        # fmt: off
        EXPECTED_INPUT_FEATURES = torch.tensor(
            [
                0.9271, 1.1405, 1.4419, 1.2470, 1.2438, 1.1787, 1.0595, 1.0570, 1.1070,
                1.2205, 1.2376, 1.2997, 1.1131, 1.0843, 1.0459, 1.1858, 1.2323, 1.3582,
                1.3401, 1.3770, 1.4173, 1.3381, 1.2291, 1.0854, 1.2116, 1.1873, 1.2178,
                1.2137, 1.3001, 1.4274
            ]
        )
        # fmt: on

        input_speech, sr = self._load_datasamples(1)

        feature_extractor = ClvpFeatureExtractor.from_pretrained("susnato/clvp_dev")
        input_features = feature_extractor(input_speech, sampling_rate=sr[0], return_tensors="pt").input_features
        self.assertEqual(input_features.shape, (1, 80, 517))
        torch.testing.assert_close(input_features[0, 0, :30], EXPECTED_INPUT_FEATURES, rtol=1e-4, atol=1e-4)
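The `@slow` integration test above pins the exact log-mel values for one 22.05 kHz LibriSpeech sample; outside the test suite the extractor is invoked the same way. A minimal sketch, with the checkpoint name taken from the test and a zero-filled array standing in for real audio:

import numpy as np

from transformers import ClvpFeatureExtractor

feature_extractor = ClvpFeatureExtractor.from_pretrained("susnato/clvp_dev")
audio = np.zeros(22050, dtype=np.float32)  # stand-in for one second of 22.05 kHz speech
features = feature_extractor(audio, sampling_rate=22050, return_tensors="pt").input_features
# features has shape (1, 80, num_frames) -- 80 mel bins, matching the (1, 80, 517)
# shape asserted in test_integration for the real sample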
docs/transformers/tests/models/clvp/test_modeling_clvp.py
ADDED
@@ -0,0 +1,640 @@
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch Clvp model."""

import tempfile
import unittest

import datasets
import numpy as np

from transformers import ClvpConfig, ClvpDecoderConfig, ClvpEncoderConfig
from transformers.testing_utils import (
    cleanup,
    require_torch,
    slow,
    torch_device,
)
from transformers.utils import is_torch_available

from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import (
    ModelTesterMixin,
    _config_zero_init,
    ids_tensor,
    random_attention_mask,
)
from ...test_pipeline_mixin import PipelineTesterMixin


if is_torch_available():
    import torch

    from transformers import ClvpEncoder, ClvpForCausalLM, ClvpModel, ClvpModelForConditionalGeneration

    from transformers import ClvpFeatureExtractor, ClvpTokenizer


class ClvpEncoderTester:
    def __init__(
        self,
        parent,
        batch_size=2,
        seq_length=7,
        is_training=False,
        use_input_mask=True,
        use_labels=True,
        vocab_size=50,
        hidden_size=128,
        projection_dim=16,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=32,
        dropout=0.1,
        attention_dropout=0.1,
        initializer_range=0.02,
        scope=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_input_mask = use_input_mask
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.projection_dim = projection_dim
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.initializer_range = initializer_range
        self.scope = scope
        self.bos_token_id = vocab_size - 1
        self.eos_token_id = vocab_size - 1

    def get_config(self):
        encoder_config = ClvpEncoderConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            projection_dim=self.projection_dim,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            dropout=self.dropout,
            attention_dropout=self.attention_dropout,
            initializer_range=self.initializer_range,
            bos_token_id=self.bos_token_id,
            eos_token_id=self.eos_token_id,
        )

        return encoder_config

    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
            input_mask = random_attention_mask([self.batch_size, self.seq_length])

        if input_mask is not None:
            batch_size, seq_length = input_mask.shape
            rnd_start_indices = np.random.randint(1, seq_length - 1, size=(batch_size,))
            for batch_idx, start_index in enumerate(rnd_start_indices):
                input_mask[batch_idx, :start_index] = 1
                input_mask[batch_idx, start_index:] = 0

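        # The loop above left-aligns each mask: with seq_length=7 and start_index=3 a
        # row becomes [1, 1, 1, 0, 0, 0, 0], so every sample keeps at least one
        # attended token and at least one padded position.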
        encoder_config = self.get_config()

        return encoder_config, input_ids, input_mask

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        speech_config, input_ids, input_mask = config_and_inputs
        inputs_dict = {"input_ids": input_ids.to(torch_device), "attention_mask": input_mask.to(torch_device)}
        return speech_config, inputs_dict

    def create_and_check_model(self, speech_config, input_ids, input_mask):
        text_config = ClvpEncoderConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            projection_dim=self.projection_dim,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            dropout=self.dropout,
            attention_dropout=self.attention_dropout,
            initializer_range=self.initializer_range,
        )
        text_encoder_model = ClvpEncoder(config=text_config)
        text_encoder_model.to(torch_device)
        text_encoder_model.eval()
        with torch.no_grad():
            result = text_encoder_model(input_ids, attention_mask=input_mask)
            result = text_encoder_model(input_ids)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
        self.parent.assertEqual(result[0].shape, (self.batch_size, self.projection_dim))

        # now check with speech config
        speech_encoder_model = ClvpEncoder(config=speech_config)
        speech_encoder_model.to(torch_device)
        speech_encoder_model.eval()
        with torch.no_grad():
            result = speech_encoder_model(input_ids, attention_mask=input_mask)
            result = speech_encoder_model(input_ids)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
        self.parent.assertEqual(result[0].shape, (self.batch_size, self.projection_dim))


@require_torch
class ClvpEncoderTest(ModelTesterMixin, unittest.TestCase):
    all_model_classes = (ClvpEncoder,) if is_torch_available() else ()
    test_pruning = False
    test_head_masking = False
    test_torchscript = False

    def setUp(self):
        self.model_tester = ClvpEncoderTester(self)
        self.encoder_config_tester = ConfigTester(self, config_class=ClvpEncoderConfig, hidden_size=32)

    def tearDown(self):
        super().tearDown()
        # clean up as much of the GPU memory occupied by PyTorch as possible
        cleanup(torch_device)

    def test_config(self):
        self.encoder_config_tester.run_common_tests()

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    @unittest.skip(reason="ClvpEncoder does not output loss")
    def test_training(self):
        pass

    @unittest.skip(reason="ClvpEncoder does not output loss")
    def test_training_gradient_checkpointing(self):
        pass


class ClvpDecoderTester:
    def __init__(
        self,
        parent,
        batch_size=2,
        seq_length=3,
        is_training=False,
        vocab_size=300,
        max_position_embeddings=256,
        max_text_tokens=256,
        use_input_mask=True,
        hidden_size=128,
        num_hidden_layers=2,
        num_attention_heads=2,
        bos_token_id=97,
        eos_token_id=98,
        relative_attention_num_buckets=4,
        relative_attention_max_distance=16,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.max_text_tokens = max_text_tokens
        self.use_input_mask = use_input_mask
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.num_hidden_layers = num_hidden_layers
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.relative_attention_num_buckets = relative_attention_num_buckets
        self.relative_attention_max_distance = relative_attention_max_distance

    def get_config(self):
        decoder_config = ClvpDecoderConfig(
            vocab_size=self.vocab_size,
            max_position_embeddings=self.max_position_embeddings,
            max_text_tokens=self.max_text_tokens,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            bos_token_id=self.bos_token_id,
            eos_token_id=self.eos_token_id,
            relative_attention_num_buckets=self.relative_attention_num_buckets,
            relative_attention_max_distance=self.relative_attention_max_distance,
        )

        return decoder_config

    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
            input_mask = random_attention_mask([self.batch_size, self.seq_length])

        if input_mask is not None:
            batch_size, seq_length = input_mask.shape
            rnd_start_indices = np.random.randint(1, seq_length - 1, size=(batch_size,))
            for batch_idx, start_index in enumerate(rnd_start_indices):
                input_mask[batch_idx, :start_index] = 1
                input_mask[batch_idx, start_index:] = 0

        decoder_config = self.get_config()

        return decoder_config, input_ids, input_mask

    def create_and_check_model(self, config, input_ids, attention_mask):
        model = ClvpForCausalLM(config).to(torch_device).eval()
        with torch.no_grad():
            result = model(input_ids=input_ids, attention_mask=attention_mask)

        self.parent.assertEqual(result[0].shape, (self.batch_size, self.seq_length, self.vocab_size))

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, input_ids, attention_mask = config_and_inputs
        inputs_dict = {
            "input_ids": input_ids.to(torch_device),
            "attention_mask": attention_mask.to(torch_device),
        }
        return config, inputs_dict


@require_torch
class ClvpDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
    all_model_classes = (ClvpModel, ClvpForCausalLM) if is_torch_available() else ()
    pipeline_model_mapping = {"feature-extraction": ClvpModelForConditionalGeneration} if is_torch_available() else {}

    test_pruning = False

    def setUp(self):
        self.model_tester = ClvpDecoderTester(self)
        self.decoder_config_tester = ConfigTester(self, config_class=ClvpDecoderConfig, hidden_size=32)

    def tearDown(self):
        super().tearDown()
        # clean up as much of the GPU memory occupied by PyTorch as possible
        cleanup(torch_device)

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
        if return_labels and model_class == ClvpForCausalLM:
            inputs_dict["labels"] = torch.zeros(
                [self.model_tester.batch_size, self.model_tester.seq_length], device=torch_device
            ).long()

        return inputs_dict

    def test_training(self):
        # we will only test ClvpForCausalLM since it outputs loss
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.return_dict = True

        model = ClvpForCausalLM(config)
        model.to(torch_device)
        model.train()
        inputs = self._prepare_for_class(inputs_dict, ClvpForCausalLM, return_labels=True)
        loss = model(**inputs).loss
        loss.backward()

    def test_training_gradient_checkpointing(self):
        # we will only test ClvpForCausalLM since it outputs loss
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.use_cache = False
        config.return_dict = True

        model = ClvpForCausalLM(config)
        model.to(torch_device)
        model.gradient_checkpointing_enable()
        model.train()
        inputs = self._prepare_for_class(inputs_dict, ClvpForCausalLM, return_labels=True)

        loss = model(**inputs).loss
        loss.backward()

    @unittest.skip(reason="Clvp `prepare_inputs_for_generation` function doesn't have cache position.")
    def test_generate_continue_from_inputs_embeds(self):
        pass


class ClvpModelForConditionalGenerationTester:
    def __init__(self, parent, is_training=False):
        self.parent = parent
        self.clvp_encoder_tester = ClvpEncoderTester(parent)
        self.is_training = is_training
        self.batch_size = self.clvp_encoder_tester.batch_size  # need bs for batching_equivalence test

    def get_config(self):
        decoder_config = ClvpDecoderConfig(
            vocab_size=50,
            max_position_embeddings=30,
            max_text_tokens=30,
            hidden_size=128,
            num_hidden_layers=1,
            num_attention_heads=2,
            bos_token_id=97,
            eos_token_id=98,
            relative_attention_num_buckets=4,
            relative_attention_max_distance=16,
        )
        text_config = self.clvp_encoder_tester.get_config()
        speech_config = self.clvp_encoder_tester.get_config()
        speech_config.vocab_size = 300

        return ClvpConfig.from_sub_model_configs(
            text_config,
            speech_config,
            decoder_config,
            projection_dim=16,
        )

    def prepare_config_and_inputs(self):
        _, input_ids, attention_mask = self.clvp_encoder_tester.prepare_config_and_inputs()

        ds = datasets.load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        ds = ds.cast_column("audio", datasets.Audio(sampling_rate=22050))
        _, audio, sr = ds.sort("id").select(range(1))[:1]["audio"][0].values()

        feature_extractor = ClvpFeatureExtractor()
        input_features = feature_extractor(raw_speech=audio, sampling_rate=sr, return_tensors="pt")[
            "input_features"
        ].to(torch_device)

        config = self.get_config()

        return config, input_ids, attention_mask, input_features

    def create_and_check_model(self, config, input_ids, attention_mask, input_features):
        model = ClvpModelForConditionalGeneration(config).to(torch_device).eval()
        with torch.no_grad():
            result = model(input_ids=input_ids, input_features=input_features, attention_mask=attention_mask)

        self.parent.assertEqual(result.logits_per_speech.shape, (2, self.clvp_encoder_tester.batch_size))
        self.parent.assertEqual(result.logits_per_text.shape, (self.clvp_encoder_tester.batch_size, 2))

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, input_ids, attention_mask, input_features = config_and_inputs
        inputs_dict = {
            "input_ids": input_ids.to(torch_device),
            "attention_mask": attention_mask.to(torch_device),
            "input_features": input_features.to(torch_device),
            "return_loss": False,
        }
        return config, inputs_dict


@require_torch
class ClvpModelForConditionalGenerationTest(ModelTesterMixin, unittest.TestCase):
    all_model_classes = (ClvpModelForConditionalGeneration,) if is_torch_available() else ()
    # Doesn't run generation tests. There are interface mismatches when using `generate` -- TODO @gante
    all_generative_model_classes = ()

    test_head_masking = False
    test_pruning = False
    test_resize_embeddings = False
    test_attention_outputs = False
    test_torchscript = False

    def setUp(self):
        self.model_tester = ClvpModelForConditionalGenerationTester(self)
        common_properties = ["projection_dim", "logit_scale_init_value"]
        self.clvp_config_tester = ConfigTester(
            self, config_class=ClvpConfig, has_text_modality=False, common_properties=common_properties, hidden_size=32
        )

    def test_config(self):
        self.clvp_config_tester.run_common_tests()

    def tearDown(self):
        super().tearDown()
        # clean up as much of the GPU memory occupied by PyTorch as possible
        cleanup(torch_device)

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_hidden_states_output(self):
        def check_hidden_states_output(inputs_dict, config, model_class):
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            # check for decoder model, text encoder model and speech encoder model hidden states
            decoder_hidden_states = outputs.decoder_hidden_states
            text_encoder_hidden_states = outputs.text_encoder_hidden_states
            speech_encoder_hidden_states = outputs.speech_encoder_hidden_states

            # check length of the hidden states
            expected_decoder_num_layers = config.decoder_config.num_hidden_layers + 1
            self.assertEqual(len(decoder_hidden_states), expected_decoder_num_layers)

            expected_speech_encoder_num_layers = config.text_config.num_hidden_layers + 1
            self.assertEqual(len(text_encoder_hidden_states), expected_speech_encoder_num_layers)

            expected_text_encoder_num_layers = config.speech_config.num_hidden_layers + 1
            self.assertEqual(len(speech_encoder_hidden_states), expected_text_encoder_num_layers)

            # check shapes of each hidden state

            # for the decoder model we will only test the dimension because the ClvpConditioningEncoder could
            # increase the sequence lengths.
            self.assertEqual(decoder_hidden_states[0].shape[-1], config.decoder_config.hidden_size)

            # the testing for the text encoder stays standard because we just pass the text tokens here.
            self.assertListEqual(
                list(text_encoder_hidden_states[0].shape[-2:]),
                [self.model_tester.clvp_encoder_tester.seq_length, config.text_config.hidden_size],
            )

            # for the speech encoder model we will only test the dimension because the fix_decoder_outputs method
            # could increase the sequence lengths by adding `decoder_fixing_codes` tokens at the end.
            self.assertEqual(speech_encoder_hidden_states[0].shape[-1], config.speech_config.hidden_size)

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            inputs_dict["output_hidden_states"] = True
            check_hidden_states_output(inputs_dict, config, model_class)

            # check that output_hidden_states also works using config
            del inputs_dict["output_hidden_states"]
            config.output_hidden_states = True

            check_hidden_states_output(inputs_dict, config, model_class)

    @unittest.skip(reason="Retain_grad is tested in individual model tests")
    def test_retain_grad_hidden_states_attentions(self):
        pass

    @unittest.skip(reason="ClvpModelForConditionalGeneration does not have get_input_embeddings")
    def test_inputs_embeds(self):
        pass

    @unittest.skip(reason="ClvpModelForConditionalGeneration does not have get_input_embeddings")
    def test_model_get_set_embeddings(self):
        pass

    # override as the `logit_scale` parameter initialization is different for Clvp
    def test_initialization(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        configs_no_init = _config_zero_init(config)
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            for name, param in model.named_parameters():
                if param.requires_grad:
                    # check if `logit_scale` is initialized as per the original implementation
                    if name == "logit_scale":
                        expected_value = np.log(1 / 0.07)
                        returned_value = param.data.item()

                        self.assertAlmostEqual(
                            returned_value,
                            expected_value,
                            delta=1e-3,
                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                        )
                    else:
                        expected_range = [0.0, 1.0]
                        returned_range = ((param.data.mean() * 1e9).round() / 1e9).item()

                        self.assertIn(
                            returned_range,
                            expected_range,
                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                        )

    def test_load_speech_text_decoder_config(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        # Save ClvpConfig and check if we can load ClvpEncoderConfig from it
        with tempfile.TemporaryDirectory() as tmp_dir_name:
            config.save_pretrained(tmp_dir_name)
            encoder_config = ClvpEncoderConfig.from_pretrained(tmp_dir_name)
            self.assertDictEqual(config.text_config.to_dict(), encoder_config.to_dict())

        # Save ClvpConfig and check if we can load ClvpDecoderConfig from it
        with tempfile.TemporaryDirectory() as tmp_dir_name:
            config.save_pretrained(tmp_dir_name)
            decoder_config = ClvpDecoderConfig.from_pretrained(tmp_dir_name)
            self.assertDictEqual(config.decoder_config.to_dict(), decoder_config.to_dict())

    @slow
    def test_model_from_pretrained(self):
        model_name = "susnato/clvp_dev"
        model = ClvpModelForConditionalGeneration.from_pretrained(model_name)
        self.assertIsNotNone(model)


# Since Clvp connects several different models, it's better to test each of them individually along
# with a test_full_model_integration. If the model breaks in the future, this makes it much easier
# to identify the broken part.


@slow
@require_torch
class ClvpIntegrationTest(unittest.TestCase):
    def setUp(self):
        self.text = "This is an example text."
        ds = datasets.load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        ds = ds.cast_column("audio", datasets.Audio(sampling_rate=22050))
        _, self.speech_samples, self.sr = ds.sort("id").select(range(1))[:1]["audio"][0].values()

        self.model = ClvpModelForConditionalGeneration.from_pretrained("susnato/clvp_dev").to(torch_device)
        self.model.eval()
        tokenizer = ClvpTokenizer.from_pretrained("susnato/clvp_dev")
        feature_extractor = ClvpFeatureExtractor.from_pretrained("susnato/clvp_dev")

        tokenizer_output = tokenizer(self.text, return_tensors="pt")
        self.text_tokens = tokenizer_output["input_ids"].to(torch_device)
        self.input_features = feature_extractor(
            raw_speech=self.speech_samples, sampling_rate=self.sr, return_tensors="pt"
        )["input_features"].to(torch_device)

    def tearDown(self):
        super().tearDown()
        # clean up as much of the GPU memory occupied by PyTorch as possible
        cleanup(torch_device, gc_collect=True)

    def test_conditional_encoder(self):
        with torch.no_grad():
            conditioning_encoder_outputs = self.model.conditioning_encoder(
                input_features=self.input_features, input_ids=self.text_tokens
            ).to("cpu")

        self.assertEqual(
            conditioning_encoder_outputs.shape,
            torch.Size((self.input_features.shape[0], 18, self.model.config.decoder_config.hidden_size)),
        )

        EXPECTED_OUTPUTS = torch.tensor(
            [[-0.8582, 0.5228, 1.9944], [-0.0465, -1.1017, -0.0093], [-0.0466, -0.6030, -0.1280]]
        )

        torch.testing.assert_close(conditioning_encoder_outputs[0, :3, :3], EXPECTED_OUTPUTS, rtol=1e-4, atol=1e-4)

    def test_decoder_model_generate(self):
        autoregressive_model_output = self.model.speech_decoder_model.generate(input_ids=self.text_tokens).cpu()

        EXPECTED_OUTPUTS = torch.tensor([[147, 2, 54, 2, 43, 2, 169, 122, 29, 64, 2, 136, 37, 33, 9, 8193]])

        torch.testing.assert_close(autoregressive_model_output, EXPECTED_OUTPUTS)

    def test_text_and_speech_encoder_models(self):
        # check for text embeds
        text_embeds = self.model.text_encoder_model(input_ids=self.text_tokens, return_dict=True)[0].cpu()

        # fmt: off
        EXPECTED_TEXT_EMBEDS = torch.tensor([1.4798, -2.0005, 2.3902, -0.5042, 1.6401, -2.4135, -1.4800, 3.0118, -2.4422, 1.3266, 2.2339, 1.4761, -4.8983, -1.3592, 6.0251, 6.7364, 2.2576, 3.7229, -10.0436, 4.6676])
        # fmt: on

        torch.testing.assert_close(text_embeds[0, :20], EXPECTED_TEXT_EMBEDS, rtol=1e-4, atol=1e-4)

        # check for speech embeds
        speech_embeds = self.model.speech_encoder_model(input_ids=self.text_tokens, return_dict=True)[0].cpu()

        # fmt: off
        EXPECTED_SPEECH_EMBEDS = torch.tensor([3.1202, -3.1183, -1.4264, -6.1339, 1.8885, -0.1983, 0.9461, -1.7414, 0.3320, -3.8400, -1.5715, 1.5096, -1.7576, 0.2387, 4.9758, 5.8450, -6.2534, 2.8587, -5.5816, 4.7821])
        # fmt: on

        torch.testing.assert_close(speech_embeds[0, :20], EXPECTED_SPEECH_EMBEDS, rtol=1e-4, atol=1e-4)

    def test_full_model_integration(self):
        full_model_output = self.model.generate(
            input_ids=self.text_tokens,
            input_features=self.input_features,
            do_sample=False,
            num_beams=4,
            num_return_sequences=4,
            max_new_tokens=10,
        )

        EXPECTED_SPEECH_IDS = torch.tensor([[1953, 1080, 612], [1953, 612, 493], [1953, 612, 716]])
        EXPECTED_SIMILARITY_SCORES = torch.tensor([[14.7660, 14.4569, 13.6472, 13.5683]])

        torch.testing.assert_close(full_model_output.speech_ids.cpu()[-3:, -3:], EXPECTED_SPEECH_IDS)
        torch.testing.assert_close(full_model_output.logits_per_text.cpu(), EXPECTED_SIMILARITY_SCORES)
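Taken together, the integration tests above mirror the intended inference flow for CLVP: tokenize a text prompt, extract conditioning features from reference speech, then let `generate` propose candidate speech token sequences and score them against the text. A condensed sketch of that flow with the same checkpoint; the zero-filled audio is a stand-in for real 22.05 kHz reference speech:

import numpy as np
import torch

from transformers import ClvpFeatureExtractor, ClvpModelForConditionalGeneration, ClvpTokenizer

checkpoint = "susnato/clvp_dev"
model = ClvpModelForConditionalGeneration.from_pretrained(checkpoint).eval()
tokenizer = ClvpTokenizer.from_pretrained(checkpoint)
feature_extractor = ClvpFeatureExtractor.from_pretrained(checkpoint)

text_tokens = tokenizer("This is an example text.", return_tensors="pt")["input_ids"]
reference_audio = np.zeros(22050, dtype=np.float32)  # stand-in; use real 22.05 kHz speech
input_features = feature_extractor(raw_speech=reference_audio, sampling_rate=22050, return_tensors="pt")["input_features"]

with torch.no_grad():
    output = model.generate(
        input_ids=text_tokens,
        input_features=input_features,
        do_sample=False,
        num_beams=4,
        num_return_sequences=4,
        max_new_tokens=10,
    )

# output.speech_ids holds the candidate speech codes; output.logits_per_text scores
# each candidate against the input text, as the assertions above check.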
docs/transformers/tests/models/clvp/test_processor_clvp.py
ADDED
@@ -0,0 +1,136 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import gc
import shutil
import tempfile
import unittest

from transformers import ClvpFeatureExtractor, ClvpProcessor, ClvpTokenizer
from transformers.testing_utils import require_torch

from .test_feature_extraction_clvp import floats_list


@require_torch
class ClvpProcessorTest(unittest.TestCase):
    def setUp(self):
        self.checkpoint = "susnato/clvp_dev"
        self.tmpdirname = tempfile.mkdtemp()

    def tearDown(self):
        super().tearDown()
        shutil.rmtree(self.tmpdirname)
        gc.collect()

    # Copied from transformers.tests.models.whisper.test_processor_whisper.WhisperProcessorTest.get_tokenizer with Whisper->Clvp
    def get_tokenizer(self, **kwargs):
        return ClvpTokenizer.from_pretrained(self.checkpoint, **kwargs)

    # Copied from transformers.tests.models.whisper.test_processor_whisper.WhisperProcessorTest.get_feature_extractor with Whisper->Clvp
    def get_feature_extractor(self, **kwargs):
        return ClvpFeatureExtractor.from_pretrained(self.checkpoint, **kwargs)

    # Copied from transformers.tests.models.whisper.test_processor_whisper.WhisperProcessorTest.test_save_load_pretrained_default with Whisper->Clvp
    def test_save_load_pretrained_default(self):
        tokenizer = self.get_tokenizer()
        feature_extractor = self.get_feature_extractor()

        processor = ClvpProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)

        processor.save_pretrained(self.tmpdirname)
        processor = ClvpProcessor.from_pretrained(self.tmpdirname)

        self.assertEqual(processor.tokenizer.get_vocab(), tokenizer.get_vocab())
        self.assertIsInstance(processor.tokenizer, ClvpTokenizer)

        self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor.to_json_string())
        self.assertIsInstance(processor.feature_extractor, ClvpFeatureExtractor)

    # Copied from transformers.tests.models.whisper.test_processor_whisper.WhisperProcessorTest.test_feature_extractor with Whisper->Clvp,processor(raw_speech->processor(raw_speech=raw_speech
    def test_feature_extractor(self):
        feature_extractor = self.get_feature_extractor()
        tokenizer = self.get_tokenizer()

        processor = ClvpProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)

        raw_speech = floats_list((3, 1000))

        input_feat_extract = feature_extractor(raw_speech, return_tensors="np")
        input_processor = processor(raw_speech=raw_speech, return_tensors="np")

        for key in input_feat_extract.keys():
            self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2)

    # Copied from transformers.tests.models.whisper.test_processor_whisper.WhisperProcessorTest.test_tokenizer with Whisper->Clvp
    def test_tokenizer(self):
        feature_extractor = self.get_feature_extractor()
        tokenizer = self.get_tokenizer()

        processor = ClvpProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)

        input_str = "This is a test string"

        encoded_processor = processor(text=input_str)

        encoded_tok = tokenizer(input_str)

        for key in encoded_tok.keys():
            self.assertListEqual(encoded_tok[key], encoded_processor[key])

    # Copied from transformers.tests.models.whisper.test_processor_whisper.WhisperProcessorTest.test_tokenizer_decode with Whisper->Clvp
    def test_tokenizer_decode(self):
        feature_extractor = self.get_feature_extractor()
        tokenizer = self.get_tokenizer()

        processor = ClvpProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)

        predicted_ids = [[1, 4, 5, 8, 1, 0, 8], [3, 4, 3, 1, 1, 8, 9]]

        decoded_processor = processor.batch_decode(predicted_ids)
        decoded_tok = tokenizer.batch_decode(predicted_ids)

        self.assertListEqual(decoded_tok, decoded_processor)

    def test_save_load_pretrained_additional_features(self):
        processor = ClvpProcessor(tokenizer=self.get_tokenizer(), feature_extractor=self.get_feature_extractor())
        processor.save_pretrained(self.tmpdirname)

        tokenizer_add_kwargs = self.get_tokenizer(pad_token="(PAD)")
        feature_extractor_add_kwargs = self.get_feature_extractor(sampling_rate=16000)

        processor = ClvpProcessor.from_pretrained(
            self.tmpdirname,
            pad_token="(PAD)",
            sampling_rate=16000,
        )

        self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab())
        self.assertIsInstance(processor.tokenizer, ClvpTokenizer)

        self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string())
        self.assertIsInstance(processor.feature_extractor, ClvpFeatureExtractor)

    def test_model_input_names(self):
        feature_extractor = self.get_feature_extractor()
        tokenizer = self.get_tokenizer()

        processor = ClvpProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)

        self.assertListEqual(
            sorted(processor.model_input_names),
            sorted(set(feature_extractor.model_input_names + tokenizer.model_input_names)),
            msg="`processor` and `feature_extractor` model input names do not match",
        )
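As the copied Whisper tests above imply, `ClvpProcessor` is a thin dispatcher: `raw_speech` goes to `ClvpFeatureExtractor`, `text` goes to `ClvpTokenizer`, and `batch_decode` defers to the tokenizer. A minimal sketch of that contract (an assumption-level illustration, requiring the `susnato/clvp_dev` checkpoint to be reachable):

import numpy as np
from transformers import ClvpProcessor

processor = ClvpProcessor.from_pretrained("susnato/clvp_dev")

# Audio path -> feature-extractor outputs (e.g. input_features).
audio = np.zeros(22050, dtype=np.float32)  # one second of silence at 22.05 kHz
audio_inputs = processor(raw_speech=audio, sampling_rate=22050, return_tensors="np")

# Text path -> tokenizer outputs (input_ids, attention_mask).
text_inputs = processor(text="This is a test string", return_tensors="np")

print(sorted(audio_inputs.keys()), sorted(text_inputs.keys()))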
docs/transformers/tests/models/clvp/test_tokenization_clvp.py
ADDED
@@ -0,0 +1,317 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import json
import os
import unittest
from functools import lru_cache

from transformers import ClvpTokenizer

from ...test_tokenization_common import TokenizerTesterMixin, slow, use_cache_if_possible


class ClvpTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
    from_pretrained_id = "susnato/clvp_dev"
    tokenizer_class = ClvpTokenizer
    test_rust_tokenizer = False
    from_pretrained_kwargs = {"add_prefix_space": True}
    test_seq2seq = False
    test_sentencepiece_ignore_case = True

    @classmethod
    def setUpClass(cls):
        super().setUpClass()

        # Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
        vocab = [
            "l",
            "o",
            "w",
            "e",
            "r",
            "s",
            "t",
            "i",
            "d",
            "n",
            "\u0120",
            "\u0120l",
            "\u0120n",
            "\u0120lo",
            "\u0120low",
            "er",
            "\u0120lowest",
            "\u0120newer",
            "\u0120wider",
            "<unk>",
            "<|endoftext|>",
            "[SPACE]",
        ]
        vocab_tokens = dict(zip(vocab, range(len(vocab))))
        merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""]
        cls.special_tokens_map = {"unk_token": "<unk>"}

        cls.vocab_file = os.path.join(cls.tmpdirname, "vocab.json")
        cls.merges_file = os.path.join(cls.tmpdirname, "merges.txt")
        with open(cls.vocab_file, "w", encoding="utf-8") as fp:
            fp.write(json.dumps(vocab_tokens) + "\n")
        with open(cls.merges_file, "w", encoding="utf-8") as fp:
            fp.write("\n".join(merges))

    # Copied from transformers.tests.models.gpt2.test_tokenization_gpt2.GPT2TokenizationTest.get_tokenizer with GPT2->Clvp
    @classmethod
    @use_cache_if_possible
    @lru_cache(maxsize=64)
    def get_tokenizer(cls, pretrained_name=None, **kwargs):
        kwargs.update(cls.special_tokens_map)
        pretrained_name = pretrained_name or cls.tmpdirname
        return ClvpTokenizer.from_pretrained(pretrained_name, **kwargs)

    # Copied from transformers.tests.models.gpt2.test_tokenization_gpt2.GPT2TokenizationTest.get_input_output_texts
    def get_input_output_texts(self, tokenizer):
        input_text = "lower newer"
        output_text = "lower[SPACE]newer"
        return input_text, output_text

    # Copied from transformers.tests.models.layoutxlm.test_tokenization_layoutxlm.LayoutXLMTokenizationTest.test_add_special_tokens
    def test_add_special_tokens(self):
        tokenizers: list[ClvpTokenizer] = self.get_tokenizers(do_lower_case=False)
        for tokenizer in tokenizers:
            with self.subTest(f"{tokenizer.__class__.__name__}"):
                special_token = "[SPECIAL_TOKEN]"
                special_token_box = [1000, 1000, 1000, 1000]

                tokenizer.add_special_tokens({"cls_token": special_token})
                encoded_special_token = tokenizer.encode(
                    [special_token], boxes=[special_token_box], add_special_tokens=False
                )
                self.assertEqual(len(encoded_special_token), 1)

                decoded = tokenizer.decode(encoded_special_token, skip_special_tokens=True)
                self.assertTrue(special_token not in decoded)

    # Copied from transformers.tests.models.gpt2.test_tokenization_gpt2.GPT2TokenizationTest.test_rust_and_python_full_tokenizers
    def test_rust_and_python_full_tokenizers(self):
        if not self.test_rust_tokenizer:
            self.skipTest(reason="test_rust_tokenizer is set to False")

        tokenizer = self.get_tokenizer()
        rust_tokenizer = self.get_rust_tokenizer(add_prefix_space=True)

        sequence = "lower newer"

        # Testing tokenization
        tokens = tokenizer.tokenize(sequence, add_prefix_space=True)
        rust_tokens = rust_tokenizer.tokenize(sequence)
        self.assertListEqual(tokens, rust_tokens)

        # Testing conversion to ids without special tokens
        ids = tokenizer.encode(sequence, add_special_tokens=False, add_prefix_space=True)
        rust_ids = rust_tokenizer.encode(sequence, add_special_tokens=False)
        self.assertListEqual(ids, rust_ids)

        # Testing conversion to ids with special tokens
        rust_tokenizer = self.get_rust_tokenizer(add_prefix_space=True)
        ids = tokenizer.encode(sequence, add_prefix_space=True)
        rust_ids = rust_tokenizer.encode(sequence)
        self.assertListEqual(ids, rust_ids)

        # Testing the unknown token
        input_tokens = tokens + [rust_tokenizer.unk_token]
        input_bpe_tokens = [14, 15, 10, 9, 3, 2, 15, 19]
        self.assertListEqual(rust_tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)

    # Copied from transformers.tests.models.gpt2.test_tokenization_gpt2.GPT2TokenizationTest.test_padding
    def test_padding(self, max_length=15):
        for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
            with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
                tokenizer_r = self.get_rust_tokenizer(pretrained_name, **kwargs)

                # Simple input
                s = "This is a simple input"
                s2 = ["This is a simple input 1", "This is a simple input 2"]
                p = ("This is a simple input", "This is a pair")
                p2 = [
                    ("This is a simple input 1", "This is a simple input 2"),
                    ("This is a simple pair 1", "This is a simple pair 2"),
                ]

                # Simple input tests
                self.assertRaises(ValueError, tokenizer_r.encode, s, max_length=max_length, padding="max_length")

                # Simple input
                self.assertRaises(ValueError, tokenizer_r.encode_plus, s, max_length=max_length, padding="max_length")

                # Simple input
                self.assertRaises(
                    ValueError,
                    tokenizer_r.batch_encode_plus,
                    s2,
                    max_length=max_length,
                    padding="max_length",
                )

                # Pair input
                self.assertRaises(ValueError, tokenizer_r.encode, p, max_length=max_length, padding="max_length")

                # Pair input
                self.assertRaises(ValueError, tokenizer_r.encode_plus, p, max_length=max_length, padding="max_length")

                # Pair input
                self.assertRaises(
                    ValueError,
                    tokenizer_r.batch_encode_plus,
                    p2,
                    max_length=max_length,
                    padding="max_length",
                )

    # Copied from transformers.tests.models.gpt2.test_tokenization_gpt2.GPT2TokenizationTest.test_padding_if_pad_token_set_slow
    def test_padding_if_pad_token_set_slow(self):
        tokenizer = ClvpTokenizer.from_pretrained(self.tmpdirname, pad_token="<pad>")

        # Simple input
        s = "This is a simple input"
        s2 = ["This is a simple input looooooooong", "This is a simple input"]
        p = ("This is a simple input", "This is a pair")
        p2 = [
            ("This is a simple input loooooong", "This is a simple input"),
            ("This is a simple pair loooooong", "This is a simple pair"),
        ]

        pad_token_id = tokenizer.pad_token_id

        out_s = tokenizer(s, padding="max_length", max_length=30, return_tensors="np")
        out_s2 = tokenizer(s2, padding=True, truncate=True, return_tensors="np")
        out_p = tokenizer(*p, padding="max_length", max_length=60, return_tensors="np")
        out_p2 = tokenizer(p2, padding=True, truncate=True, return_tensors="np")

        # s
        # test single string max_length padding
        self.assertEqual(out_s["input_ids"].shape[-1], 30)
        self.assertTrue(pad_token_id in out_s["input_ids"])
        self.assertTrue(0 in out_s["attention_mask"])

        # s2
        # test automatic padding
        self.assertEqual(out_s2["input_ids"].shape[-1], 33)
        # long slice doesn't have padding
        self.assertFalse(pad_token_id in out_s2["input_ids"][0])
        self.assertFalse(0 in out_s2["attention_mask"][0])
        # short slice does have padding
        self.assertTrue(pad_token_id in out_s2["input_ids"][1])
        self.assertTrue(0 in out_s2["attention_mask"][1])

        # p
        # test single pair max_length padding
        self.assertEqual(out_p["input_ids"].shape[-1], 60)
        self.assertTrue(pad_token_id in out_p["input_ids"])
        self.assertTrue(0 in out_p["attention_mask"])

        # p2
        # test automatic padding pair
        self.assertEqual(out_p2["input_ids"].shape[-1], 52)
        # long slice pair doesn't have padding
        self.assertFalse(pad_token_id in out_p2["input_ids"][0])
        self.assertFalse(0 in out_p2["attention_mask"][0])
        # short slice pair does have padding
        self.assertTrue(pad_token_id in out_p2["input_ids"][1])
        self.assertTrue(0 in out_p2["attention_mask"][1])

    # Copied from transformers.tests.models.gpt2.test_tokenization_gpt2.GPT2TokenizationTest.test_special_tokens_mask_input_pairs_and_bos_token
    def test_special_tokens_mask_input_pairs_and_bos_token(self):
        # TODO: change to self.get_tokenizers() when the fast version is implemented
        tokenizers = [self.get_tokenizer(do_lower_case=False, add_bos_token=True)]
        for tokenizer in tokenizers:
            with self.subTest(f"{tokenizer.__class__.__name__}"):
                sequence_0 = "Encode this."
                sequence_1 = "This one too please."
                encoded_sequence = tokenizer.encode(sequence_0, add_special_tokens=False)
                encoded_sequence += tokenizer.encode(sequence_1, add_special_tokens=False)
                encoded_sequence_dict = tokenizer.encode_plus(
                    sequence_0,
                    sequence_1,
                    add_special_tokens=True,
                    return_special_tokens_mask=True,
                )
                encoded_sequence_w_special = encoded_sequence_dict["input_ids"]
                special_tokens_mask = encoded_sequence_dict["special_tokens_mask"]
                self.assertEqual(len(special_tokens_mask), len(encoded_sequence_w_special))

                filtered_sequence = [
                    (x if not special_tokens_mask[i] else None) for i, x in enumerate(encoded_sequence_w_special)
                ]
                filtered_sequence = [x for x in filtered_sequence if x is not None]
                self.assertEqual(encoded_sequence, filtered_sequence)

    def test_token_type_ids(self):
        tokenizer = self.get_tokenizer()
        seq_0 = "Test this method."

        # We want to have sequence 0 and sequence 1 are tagged
        # respectively with 0 and 1 token_ids
        # (regardless of whether the model use token type ids)
        # We use this assumption in the QA pipeline among other place
        output = tokenizer(seq_0, return_token_type_ids=True, add_special_tokens=True)
        self.assertIn(0, output["token_type_ids"])

    def test_full_tokenizer(self):
        tokenizer = ClvpTokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map)
        text = "lower newer"
        bpe_tokens = ["l", "o", "w", "er", "[SPACE]", "n", "e", "w", "er"]
        tokens = tokenizer.tokenize(text, add_prefix_space=False)
        self.assertListEqual(tokens, bpe_tokens)

        input_tokens = tokens + [tokenizer.unk_token]
        input_bpe_tokens = [0, 1, 2, 15, 21, 9, 3, 2, 15, 19]
        self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)

    @slow
    def test_outputs_with_numbers(self):
        text = "hello and this is an example text and I have $1000. my lucky number is 12345."
        tokenizer = ClvpTokenizer.from_pretrained("susnato/clvp_dev")

        # fmt: off
        EXPECTED_OUTPUT = [62, 84, 28, 2, 53, 2, 147, 2, 54, 2, 43, 2, 169, 122, 29, 64, 2, 136, 37, 33, 2, 53, 2, 22,
                           2, 148, 2, 110, 2, 40, 206, 53, 2, 134, 84, 59, 32, 9, 2, 125, 2, 25, 34, 197, 38, 2, 27,
                           231, 15, 44, 2, 54, 2, 33, 100, 25, 76, 2, 40, 206, 53, 7, 2, 40, 46, 18, 2, 21, 97, 17,
                           219, 2, 87, 210, 8, 19, 22, 76, 9,
                           ]
        # fmt: on

        self.assertListEqual(tokenizer.encode(text, add_special_tokens=False), EXPECTED_OUTPUT)

    @slow
    def test_tokenizer_integration(self):
        sequences = [
            "Transformers (formerly known as pytorch-transformers and pytorch-pretrained-bert) provides "
            "general-purpose architectures (BERT, RoBERTa, XLM, DistilBert, XLNet...) for Natural "
            "Language Understanding (NLU) and Natural Language Generation (NLG) with over multiple pretrained "
            "models and deep interoperability between Jax, PyTorch and TensorFlow.",
            "BERT is designed to pre-train deep bidirectional representations from unlabeled text by jointly "
            "conditioning on both left and right context in all layers.",
            "The quick brown fox jumps over the lazy dog.",
        ]

        # fmt: off
        expected_encoding = {'input_ids': [[144, 43, 32, 87, 26, 173, 2, 5, 87, 26, 44, 70, 2, 209, 27, 2, 55, 2, 29, 38, 51, 31, 71, 8, 144, 43, 32, 87, 26, 173, 2, 53, 2, 29, 38, 51, 31, 71, 8, 29, 46, 144, 137, 49, 8, 15, 44, 33, 6, 2, 187, 35, 83, 61, 2, 20, 50, 44, 56, 8, 29, 121, 139, 66, 2, 59, 71, 60, 18, 16, 33, 34, 175, 2, 5, 15, 44, 33, 7, 2, 89, 15, 44, 33, 14, 7, 2, 37, 25, 26, 7, 2, 17, 54, 78, 25, 15, 44, 33, 7, 2, 37, 25, 111, 33, 9, 9, 9, 6, 2, 87, 2, 27, 48, 121, 56, 2, 25, 43, 20, 34, 14, 112, 2, 97, 234, 63, 53, 52, 2, 5, 27, 25, 34, 6, 2, 53, 2, 27, 48, 121, 56, 2, 25, 43, 20, 34, 14, 112, 2, 20, 50, 44, 158, 2, 5, 27, 25, 20, 6, 2, 103, 2, 253, 2, 26, 167, 78, 29, 64, 2, 29, 46, 144, 137, 49, 2, 115, 126, 25, 32, 2, 53, 2, 126, 18, 29, 2, 41, 114, 161, 44, 109, 151, 240, 2, 67, 33, 100, 50, 2, 23, 14, 37, 7, 2, 29, 38, 51, 31, 71, 2, 53, 2, 33, 50, 32, 57, 19, 25, 69, 9], [ 15, 44, 33, 2, 54, 2, 17, 61, 22, 20, 27, 49, 2, 51, 2, 29, 46, 8, 144, 137, 2, 126, 18, 29, 2, 15, 83, 22, 46, 16, 181, 56, 2, 46, 29, 175, 86, 158, 32, 2, 154, 2, 97, 25, 14, 67, 25, 49, 2, 136, 37, 33, 2, 185, 2, 23, 28, 41, 33, 70, 2, 135, 17, 60, 107, 52, 2, 47, 2, 165, 40, 2, 64, 19, 33, 2, 53, 2, 101, 104, 2, 135, 136, 37, 33, 2, 41, 2, 108, 2, 25, 88, 173, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [ 42, 2, 194, 91, 24, 2, 243, 190, 2, 182, 37, 2, 23, 231, 29, 32, 2, 253, 2, 42, 2, 25, 14, 39, 38, 2, 134, 20, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], # noqa: E501
        'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], # noqa: E501
        }
        # fmt: on

        self.tokenizer_integration_test_util(
            sequences=sequences, expected_encoding=expected_encoding, model_name="susnato/clvp_dev", padding=True
        )
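The toy vocabulary and merges above make the `[SPACE]` behaviour easy to reproduce outside the test class. A standalone sketch that mirrors test_full_tokenizer (the temp directory is illustrative):

import json
import os
import tempfile

from transformers import ClvpTokenizer

# Same toy BPE as in setUpClass above.
vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n", "\u0120", "\u0120l", "\u0120n", "\u0120lo",
         "\u0120low", "er", "\u0120lowest", "\u0120newer", "\u0120wider", "<unk>", "<|endoftext|>", "[SPACE]"]
merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""]

tmpdir = tempfile.mkdtemp()
vocab_file = os.path.join(tmpdir, "vocab.json")
merges_file = os.path.join(tmpdir, "merges.txt")
with open(vocab_file, "w", encoding="utf-8") as fp:
    fp.write(json.dumps(dict(zip(vocab, range(len(vocab))))) + "\n")
with open(merges_file, "w", encoding="utf-8") as fp:
    fp.write("\n".join(merges))

tokenizer = ClvpTokenizer(vocab_file, merges_file, unk_token="<unk>")
print(tokenizer.tokenize("lower newer", add_prefix_space=False))
# Expected, per test_full_tokenizer: ['l', 'o', 'w', 'er', '[SPACE]', 'n', 'e', 'w', 'er']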
docs/transformers/tests/models/code_llama/__init__.py
ADDED
File without changes
docs/transformers/tests/models/code_llama/test_tokenization_code_llama.py
ADDED
@@ -0,0 +1,653 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import pickle
import shutil
import tempfile
import unittest

from datasets import load_dataset

from transformers import (
    SPIECE_UNDERLINE,
    AddedToken,
    CodeLlamaTokenizer,
    CodeLlamaTokenizerFast,
)
from transformers.convert_slow_tokenizer import convert_slow_tokenizer
from transformers.testing_utils import (
    get_tests_dir,
    nested_simplify,
    require_sentencepiece,
    require_tokenizers,
    require_torch,
    slow,
)

from ...test_tokenization_common import TokenizerTesterMixin


SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")


@require_sentencepiece
@require_tokenizers
class CodeLlamaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
    from_pretrained_id = "hf-internal-testing/llama-code-tokenizer"
    tokenizer_class = CodeLlamaTokenizer
    rust_tokenizer_class = CodeLlamaTokenizerFast
    test_rust_tokenizer = False
    test_sentencepiece = True
    from_pretrained_kwargs = {}

    @classmethod
    def setUpClass(cls):
        super().setUpClass()

        # We have a SentencePiece fixture for testing
        tokenizer = CodeLlamaTokenizer(SAMPLE_VOCAB, keep_accents=True)
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.save_pretrained(cls.tmpdirname)

    def get_tokenizers(cls, **kwargs):
        kwargs.update({"pad_token": "<PAD>"})
        return super().get_tokenizers(**kwargs)

    def test_no_infilling_init(self):
        tokenizer = CodeLlamaTokenizer(SAMPLE_VOCAB, prefix_token=None, keep_accents=True)
        with self.assertRaises(ValueError):
            tokenizer.tokenize("This is <FILL_ME> prefix")

    def test_full_tokenizer(self):
        tokenizer = CodeLlamaTokenizer(SAMPLE_VOCAB, keep_accents=True)

        tokens = tokenizer.tokenize("This is a test")
        self.assertListEqual(tokens, ["▁This", "▁is", "▁a", "▁t", "est"])

        self.assertListEqual(
            tokenizer.convert_tokens_to_ids(tokens),
            [285, 46, 10, 170, 382],
        )

        tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.")
        self.assertListEqual(
            tokens,
            [
                SPIECE_UNDERLINE + "I",
                SPIECE_UNDERLINE + "was",
                SPIECE_UNDERLINE + "b",
                "or",
                "n",
                SPIECE_UNDERLINE + "in",
                SPIECE_UNDERLINE + "",
                "9",
                "2",
                "0",
                "0",
                "0",
                ",",
                SPIECE_UNDERLINE + "and",
                SPIECE_UNDERLINE + "this",
                SPIECE_UNDERLINE + "is",
                SPIECE_UNDERLINE + "f",
                "al",
                "s",
                "é",
                ".",
            ],
        )
        ids = tokenizer.convert_tokens_to_ids(tokens)
        self.assertListEqual(
            ids,
            [8, 21, 84, 55, 24, 19, 7, 0, 602, 347, 347, 347, 3, 12, 66, 46, 72, 80, 6, 0, 4],
        )

        back_tokens = tokenizer.convert_ids_to_tokens(ids)
        self.assertListEqual(
            back_tokens,
            [
                SPIECE_UNDERLINE + "I",
                SPIECE_UNDERLINE + "was",
                SPIECE_UNDERLINE + "b",
                "or",
                "n",
                SPIECE_UNDERLINE + "in",
                SPIECE_UNDERLINE + "",
                "<unk>",
                "2",
                "0",
                "0",
                "0",
                ",",
                SPIECE_UNDERLINE + "and",
                SPIECE_UNDERLINE + "this",
                SPIECE_UNDERLINE + "is",
                SPIECE_UNDERLINE + "f",
                "al",
                "s",
                "<unk>",
                ".",
            ],
        )

    def test_save_pretrained(self):
        self.tokenizers_list = [
            (self.rust_tokenizer_class, "hf-internal-testing/llama-code-tokenizer", {}),
            (self.tokenizer_class, "hf-internal-testing/llama-code-tokenizer", {}),
            (self.tokenizer_class, "codellama/CodeLlama-34b-Instruct-hf", {}),
            (self.rust_tokenizer_class, "codellama/CodeLlama-34b-Instruct-hf", {}),
        ]
        for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
            with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
                tokenizer_r = self.get_rust_tokenizer(pretrained_name, **kwargs)
                tokenizer_p = self.get_tokenizer(pretrained_name, **kwargs)

                tmpdirname2 = tempfile.mkdtemp()

                tokenizer_r_files = tokenizer_r.save_pretrained(tmpdirname2)
                tokenizer_p_files = tokenizer_p.save_pretrained(tmpdirname2)

                # Checks it save with the same files + the tokenizer.json file for the fast one
                self.assertTrue(any("tokenizer.json" in f for f in tokenizer_r_files))
                tokenizer_r_files = tuple(f for f in tokenizer_r_files if "tokenizer.json" not in f)
                self.assertSequenceEqual(tokenizer_r_files, tokenizer_p_files)

                # Checks everything loads correctly in the same way
                tokenizer_rp = tokenizer_r.from_pretrained(tmpdirname2)
                tokenizer_pp = tokenizer_p.from_pretrained(tmpdirname2)

                # Check special tokens are set accordingly on Rust and Python
                for key in tokenizer_pp.special_tokens_map:
                    self.assertTrue(hasattr(tokenizer_rp, key))

                shutil.rmtree(tmpdirname2)

                # Save tokenizer rust, legacy_format=True
                tmpdirname2 = tempfile.mkdtemp()

                tokenizer_r_files = tokenizer_r.save_pretrained(tmpdirname2, legacy_format=True)
                tokenizer_p_files = tokenizer_p.save_pretrained(tmpdirname2)

                # Checks it save with the same files
                self.assertSequenceEqual(tokenizer_r_files, tokenizer_p_files)

                # Checks everything loads correctly in the same way
                tokenizer_rp = tokenizer_r.from_pretrained(tmpdirname2)
                tokenizer_pp = tokenizer_p.from_pretrained(tmpdirname2)

                # Check special tokens are set accordingly on Rust and Python
                for key in tokenizer_pp.special_tokens_map:
                    self.assertTrue(hasattr(tokenizer_rp, key))

                shutil.rmtree(tmpdirname2)

                # Save tokenizer rust, legacy_format=False
                tmpdirname2 = tempfile.mkdtemp()

                tokenizer_r_files = tokenizer_r.save_pretrained(tmpdirname2, legacy_format=False)
                tokenizer_p_files = tokenizer_p.save_pretrained(tmpdirname2)

                # Checks it saved the tokenizer.json file
                self.assertTrue(any("tokenizer.json" in f for f in tokenizer_r_files))

                # Checks everything loads correctly in the same way
                tokenizer_rp = tokenizer_r.from_pretrained(tmpdirname2)
                tokenizer_pp = tokenizer_p.from_pretrained(tmpdirname2)

                # Check special tokens are set accordingly on Rust and Python
                for key in tokenizer_pp.special_tokens_map:
                    self.assertTrue(hasattr(tokenizer_rp, key))

                shutil.rmtree(tmpdirname2)

    @require_torch
    def test_batch_tokenization(self):
        if not self.test_seq2seq:
            self.skipTest(reason="test_seq2seq is False")

        tokenizers = self.get_tokenizers()
        for tokenizer in tokenizers:
            with self.subTest(f"{tokenizer.__class__.__name__}"):
                # Longer text that will definitely require truncation.
                text = [
                    " UN Chief Says There Is No Military Solution in Syria",
                    " Secretary-General Ban Ki-moon says his response to Russia's stepped up military support for"
                    " Syria is that 'there is no military solution' to the nearly five-year conflict and more weapons"
                    " will only worsen the violence and misery for millions of people.",
                ]
                try:
                    batch = tokenizer(
                        text=text,
                        max_length=3,
                        max_target_length=10,
                        return_tensors="pt",
                    )
                except NotImplementedError:
                    self.skipTest(reason="Encountered NotImplementedError when calling tokenizer")
                self.assertEqual(batch.input_ids.shape[1], 3)
                # max_target_length will default to max_length if not specified
                batch = tokenizer(text, max_length=3, return_tensors="pt")
                self.assertEqual(batch.input_ids.shape[1], 3)

                batch_encoder_only = tokenizer(text=text, max_length=3, max_target_length=10, return_tensors="pt")
                self.assertEqual(batch_encoder_only.input_ids.shape[1], 3)
                self.assertEqual(batch_encoder_only.attention_mask.shape[1], 3)
                self.assertNotIn("decoder_input_ids", batch_encoder_only)

    @unittest.skip(reason="Unfortunately way too slow to build a BPE with SentencePiece.")
    def test_save_slow_from_fast_and_reload_fast(self):
        pass

    def test_special_tokens_initialization(self):
        for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
            with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
                added_tokens = [AddedToken("<special>", lstrip=True)]

                tokenizer_r = self.get_rust_tokenizer(
                    pretrained_name, additional_special_tokens=added_tokens, **kwargs
                )
                r_output = tokenizer_r.encode("Hey this is a <special> token")

                special_token_id = tokenizer_r.encode("<special>", add_special_tokens=False)[0]

                self.assertTrue(special_token_id in r_output)

                if self.test_slow_tokenizer:
                    tokenizer_cr = self.get_rust_tokenizer(
                        pretrained_name,
                        additional_special_tokens=added_tokens,
                        **kwargs,  # , from_slow=True <- unfortunately too slow to convert
                    )
                    tokenizer_p = self.tokenizer_class.from_pretrained(
                        pretrained_name, additional_special_tokens=added_tokens, **kwargs
                    )

                    p_output = tokenizer_p.encode("Hey this is a <special> token")

                    cr_output = tokenizer_cr.encode("Hey this is a <special> token")

                    self.assertEqual(p_output, r_output)
                    self.assertEqual(cr_output, r_output)
                    self.assertTrue(special_token_id in p_output)
                    self.assertTrue(special_token_id in cr_output)

    @slow
    def test_tokenizer_integration(self):
        expected_encoding = {'input_ids': [[1, 4103, 689, 414, 313, 24784, 368, 2998, 408, 282, 3637, 25350, 29899, 9067, 414, 322, 282, 3637, 25350, 29899, 1457, 3018, 1312, 29899, 2151, 29897, 8128, 2498, 29899, 15503, 4220, 6956, 1973, 313, 13635, 29911, 29892, 402, 7982, 29899, 29906, 29892, 1528, 13635, 29911, 29874, 29892, 1060, 26369, 29892, 6652, 309, 29933, 814, 29892, 1060, 29931, 6779, 11410, 363, 18385, 17088, 7634, 11235, 313, 25103, 29965, 29897, 322, 18385, 17088, 28203, 313, 25103, 29954, 29897, 411, 975, 29871, 29941, 29906, 29974, 758, 3018, 1312, 4733, 297, 29871, 29896, 29900, 29900, 29974, 10276, 322, 6483, 1006, 3372, 3097, 1546, 435, 1165, 29892, 10772, 29911, 25350, 322, 323, 6073, 17907, 29889], [1, 350, 20161, 338, 8688, 304, 758, 29899, 14968, 6483, 21000, 8684, 284, 22540, 515, 443, 29880, 24025, 1426, 491, 14002, 368, 4195, 292, 373, 1716, 2175, 322, 1492, 3030, 297, 599, 15359, 29889], [1, 450, 4996, 17354, 1701, 29916, 432, 17204, 975, 278, 17366, 11203, 29889]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]} # fmt: skip

        self.tokenizer_integration_test_util(
            expected_encoding=expected_encoding,
            model_name="hf-internal-testing/llama-code-tokenizer",
            revision="6eb30c03ab6a9e2cdef4d523024909ec815ddb75",
            padding=False,
        )

    def test_picklable(self):
        with tempfile.NamedTemporaryFile() as f:
            shutil.copyfile(SAMPLE_VOCAB, f.name)
            tokenizer = CodeLlamaTokenizer(f.name, keep_accents=True)
            pickled_tokenizer = pickle.dumps(tokenizer)
        pickle.loads(pickled_tokenizer)

    @unittest.skip(reason="worker 'gw4' crashed on CI, passing locally.")
    def test_pickle_subword_regularization_tokenizer(self):
        pass

    @unittest.skip(reason="worker 'gw4' crashed on CI, passing locally.")
    def test_subword_regularization_tokenizer(self):
        pass


@require_torch
@require_sentencepiece
@require_tokenizers
class LlamaIntegrationTest(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        checkpoint_name = "hf-internal-testing/llama-code-tokenizer"
        cls.tokenizer: CodeLlamaTokenizer = CodeLlamaTokenizer.from_pretrained(checkpoint_name)
        cls.rust_tokenizer = CodeLlamaTokenizerFast.from_pretrained(checkpoint_name)
        return cls

    @require_torch
    def integration_tests(self):
        inputs = self.tokenizer(
            ["The following string should be properly encoded: Hello.", "But ird and ปี ird ด"],
            return_tensors="pt",
        )

        self.assertEqual(
            nested_simplify(inputs),
            {
                "input_ids": [
                    [1, 450, 1494, 1347, 881, 367, 6284, 18511, 29901, 15043, 29889],
                    [1, 1205, 29871, 1823, 322, 29871, 31010, 30691, 1678, 1823, 1678, 30718],
                ],
                "attention_mask": [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
            },
        )

    def test_fast_special_tokens(self):
        slow_tokenizer = self.tokenizer
        fast_tokenizer = self.rust_tokenizer
        slow = slow_tokenizer.encode("A sample test", add_special_tokens=True)
        assert slow == [1, 319, 4559, 1243]

        fast_tokenizer.add_eos_token = False
        fast = fast_tokenizer.encode("A sample test", add_special_tokens=True)
        assert fast == [1, 319, 4559, 1243]

        fast_tokenizer.add_eos_token = True
        fast = fast_tokenizer.encode("A sample test", add_special_tokens=True)
        assert fast == [1, 319, 4559, 1243, 2]

        slow_tokenizer.add_eos_token = True
        slow = slow_tokenizer.encode("A sample test", add_special_tokens=True)
        assert slow == [1, 319, 4559, 1243, 2]

        fast_tokenizer = CodeLlamaTokenizerFast.from_pretrained(
            "hf-internal-testing/llama-tokenizer", add_eos_token=True, add_bos_token=False
        )
        fast = fast_tokenizer.encode("A sample test", add_special_tokens=True)
        assert fast == [319, 4559, 1243, 2]

        slow_tokenizer = CodeLlamaTokenizer.from_pretrained(
            "hf-internal-testing/llama-tokenizer", add_eos_token=True, add_bos_token=False
        )
        slow = slow_tokenizer.encode("A sample test", add_special_tokens=True)
        assert slow == [319, 4559, 1243, 2]

        self.tokenizer.add_eos_token = False
        self.rust_tokenizer.add_eos_token = False

    @slow
    def test_conversion(self):
        # This is excruciatingly slow since it has to recreate the entire merge
        # list from the original vocabulary in spm
        self.rust_tokenizer.save_pretrained("./out")
        with tempfile.TemporaryDirectory() as dirname:
            self.rust_tokenizer.save_pretrained(dirname)

            with open(os.path.join(dirname, "tokenizer.json")) as f:
                old_serialized = f.read()

        new_tokenizer = convert_slow_tokenizer(self.tokenizer)
        with tempfile.NamedTemporaryFile() as f:
            new_tokenizer.save(f.name)
            # Re-opening since `f` is in bytes.
            new_serialized = open(f.name).read()
            with open("out_tokenizer.json", "w") as g:
                g.write(new_serialized)

            self.assertEqual(old_serialized, new_serialized)

    def test_simple_encode_decode(self):
        pyth_tokenizer = self.tokenizer
        rust_tokenizer = self.rust_tokenizer

        self.assertEqual(pyth_tokenizer.encode("This is a test"), [1, 910, 338, 263, 1243])
        self.assertEqual(rust_tokenizer.encode("This is a test"), [1, 910, 338, 263, 1243])
        self.assertEqual(pyth_tokenizer.decode([1, 910, 338, 263, 1243], skip_special_tokens=True), "This is a test")
        self.assertEqual(rust_tokenizer.decode([1, 910, 338, 263, 1243], skip_special_tokens=True), "This is a test")

        # bytefallback showcase
        self.assertEqual(pyth_tokenizer.encode("生活的真谛是"), [1, 29871, 30486, 31704, 30210, 30848, 235, 179, 158, 30392])  # fmt: skip
        self.assertEqual(rust_tokenizer.encode("生活的真谛是"), [1, 29871, 30486, 31704, 30210, 30848, 235, 179, 158, 30392])  # fmt: skip
        self.assertEqual(
            pyth_tokenizer.decode(
                [1, 29871, 30486, 31704, 30210, 30848, 235, 179, 158, 30392], skip_special_tokens=True
            ),
            "生活的真谛是",
        )
        self.assertEqual(
            rust_tokenizer.decode(
                [1, 29871, 30486, 31704, 30210, 30848, 235, 179, 158, 30392], skip_special_tokens=True
            ),
            "生活的真谛是",
        )

        # Inner spaces showcase
        self.assertEqual(pyth_tokenizer.encode("Hi Hello"), [1, 6324, 29871, 15043])
        self.assertEqual(rust_tokenizer.encode("Hi Hello"), [1, 6324, 29871, 15043])
        self.assertEqual(pyth_tokenizer.decode([1, 6324, 29871, 15043], skip_special_tokens=True), "Hi Hello")
        self.assertEqual(rust_tokenizer.decode([1, 6324, 29871, 15043], skip_special_tokens=True), "Hi Hello")

        self.assertEqual(pyth_tokenizer.encode("Hi  Hello"), [1, 6324, 259, 15043])
        self.assertEqual(rust_tokenizer.encode("Hi  Hello"), [1, 6324, 259, 15043])
        self.assertEqual(pyth_tokenizer.decode([1, 6324, 259, 15043], skip_special_tokens=True), "Hi  Hello")
        self.assertEqual(rust_tokenizer.decode([1, 6324, 259, 15043], skip_special_tokens=True), "Hi  Hello")

        self.assertEqual(pyth_tokenizer.encode(""), [1])
        self.assertEqual(rust_tokenizer.encode(""), [1])

        self.assertEqual(pyth_tokenizer.encode(" "), [1, 259])
        self.assertEqual(rust_tokenizer.encode(" "), [1, 259])

        self.assertEqual(pyth_tokenizer.encode("  "), [1, 1678])
        self.assertEqual(rust_tokenizer.encode("  "), [1, 1678])

        self.assertEqual(pyth_tokenizer.encode(" Hello"), [1, 29871, 15043])
        self.assertEqual(rust_tokenizer.encode(" Hello"), [1, 29871, 15043])

    def test_no_differences_showcase(self):
        pyth_tokenizer = self.tokenizer
        rust_tokenizer = self.rust_tokenizer
        self.assertEqual(pyth_tokenizer.encode(""), [1])
        self.assertEqual(rust_tokenizer.encode(""), [1])

        self.assertEqual(pyth_tokenizer.encode(" "), [1, 259])
        self.assertEqual(rust_tokenizer.encode(" "), [1, 259])

        self.assertEqual(pyth_tokenizer.encode("  "), [1, 1678])
        self.assertEqual(rust_tokenizer.encode("  "), [1, 1678])

        self.assertEqual(pyth_tokenizer.encode(" Hello"), [1, 29871, 15043])
        self.assertEqual(rust_tokenizer.encode(" Hello"), [1, 29871, 15043])

        self.assertEqual(pyth_tokenizer.encode("<s>"), [1, 1])
        self.assertEqual(rust_tokenizer.encode("<s>"), [1, 1])

    def test_no_differences_decode(self):
        pyth_tokenizer = self.tokenizer
        rust_tokenizer = self.rust_tokenizer

        self.assertEqual(pyth_tokenizer.decode([869]), ".")
        self.assertEqual(rust_tokenizer.decode([869]), ".")

        self.assertEqual(pyth_tokenizer.decode([30112, 869]), "ا .")
        self.assertEqual(rust_tokenizer.decode([30112, 869]), "ا .")

    def test_no_differences_special_tokens(self):
        pyth_tokenizer = self.tokenizer
        rust_tokenizer = self.rust_tokenizer
        self.assertEqual(pyth_tokenizer.encode(""), [1])
        self.assertEqual(rust_tokenizer.encode(""), [1])

        self.assertEqual(pyth_tokenizer.encode("<s>"), [1, 1])
        self.assertEqual(rust_tokenizer.encode("<s>"), [1, 1])

    @unittest.skipIf(
        os.getenv("RUN_TOKENIZER_INTEGRATION", "0") == "0",
        "RUN_TOKENIZER_INTEGRATION=1 to run tokenizer integration tests",
    )
    def test_integration_test_xnli(self):
        import tqdm

        pyth_tokenizer = self.tokenizer
        rust_tokenizer = self.rust_tokenizer

        dataset = load_dataset("google/code_x_glue_ct_code_to_text", "go")
        for item in tqdm.tqdm(dataset["validation"]):
            string = item["code"]
            encoded1 = pyth_tokenizer.encode(string)
            encoded2 = rust_tokenizer.encode(string)

            self.assertEqual(encoded1, encoded2)

            decoded1 = pyth_tokenizer.decode(encoded1, skip_special_tokens=True)
            decoded2 = rust_tokenizer.decode(encoded2, skip_special_tokens=True)

            self.assertEqual(decoded1, decoded2)

        dataset = load_dataset("facebook/xnli", "all_languages")

        for item in tqdm.tqdm(dataset["train"]):
            for string in item["premise"].values():
                encoded1 = pyth_tokenizer.encode(string)
                encoded2 = rust_tokenizer.encode(string)

                self.assertEqual(encoded1, encoded2)

                decoded1 = pyth_tokenizer.decode(encoded1, skip_special_tokens=True)
                decoded2 = rust_tokenizer.decode(encoded2, skip_special_tokens=True)

                self.assertEqual(decoded1, decoded2)

    def test_special_token_special_word(self):
        # the word inform should be split as ['in', 'form']
        tokenizer = CodeLlamaTokenizer.from_pretrained("codellama/CodeLlama-7b-hf", legacy=False)
        tokenizer.add_tokens([AddedToken("<REPR_END>", rstrip=True, lstrip=True)], special_tokens=False)
        out1 = tokenizer.decode(
            tokenizer.encode("<REPR_END>inform", add_special_tokens=False), spaces_between_special_tokens=False
        )
        self.assertEqual(out1, "<REPR_END>inform")
        out2 = tokenizer.decode(
            tokenizer.encode("<REPR_END>inform", add_special_tokens=False), spaces_between_special_tokens=True
        )
        # the added prefix token should not be decoded
        self.assertEqual(out2, "<REPR_END> inform")
        input_ids = tokenizer.encode("<REPR_END>inform", add_special_tokens=False)
        self.assertEqual(input_ids, [29871, 32016, 262, 689])  # 29871 is the spiece underline, '▁'

        out2 = tokenizer.decode(
            tokenizer.encode(" <REPR_END> inform", add_special_tokens=False), spaces_between_special_tokens=False
        )
        # TODO @ArthurZ currently we strip left and right, so this will not keep the spaces
        self.assertEqual(out2, "<REPR_END>inform")

        ### Let's make sure decoding does not add extra spaces here and there
        # TODO @ArthurZ this should be affected by the lstrip/rstrip/single word /normalize refactoring
        # Since currently we always strip left and right of the token, results are as such
        input_ids = tokenizer.encode("<s> Hello<s>how", add_special_tokens=False)
        self.assertEqual(input_ids, [1, 15043, 1, 3525])
        tokens = tokenizer.tokenize("<s> Hello<s>how", add_special_tokens=False)
        self.assertEqual(tokens, ["<s>", "▁Hello", "<s>", "how"])
        decoded_tokens = tokenizer.decode(input_ids)
        self.assertEqual(decoded_tokens, "<s> Hello<s>how")

        # Let's make sure that if there are any spaces, we don't remove them!
        input_ids = tokenizer.encode("  <s> Hello<s> how", add_special_tokens=False)
        self.assertEqual(input_ids, [259, 1, 15043, 1, 920])
        tokens = tokenizer.tokenize("  <s> Hello<s> how", add_special_tokens=False)
        self.assertEqual(tokens, ["▁▁", "<s>", "▁Hello", "<s>", "▁how"])
        decoded_tokens = tokenizer.decode(input_ids)
        self.assertEqual(decoded_tokens, "  <s> Hello<s> how")

    def test_fill_token(self):
        tokenizer = CodeLlamaTokenizerFast.from_pretrained(
            "codellama/CodeLlama-7b-hf", fill_token=None, prefix_token=None, suffix_token=None, middle_token=None
        )
        tokenizer.encode_plus("Hey how are you").input_ids
        tokenizer.fill_token = "<FILL_ME>"
        with self.assertRaises(ValueError):
            tokenizer.encode("Hey how <FILL_ME> are you")
            tokenizer.encode_plus("Hey how <FILL_ME> are you", "mne too")
tokenizer.tokenize("Hey how are you", "mne too")
|
| 568 |
+
|
| 569 |
+
tokenizer = CodeLlamaTokenizerFast.from_pretrained(
|
| 570 |
+
"codellama/CodeLlama-7b-hf", revision="3773f63b4511b9e47a9a7ffc765eed7eb0169486"
|
| 571 |
+
)
|
| 572 |
+
tokenizer.encode("Hey how <FILL_ME> are you")
|
| 573 |
+
tokenizer.encode_plus("Hey how <FILL_ME> are you", "mne too")
|
| 574 |
+
tokenizer.tokenize("Hey how are you", "mne too")
|
| 575 |
+
|
| 576 |
+
def test_spm_edge_cases(self):
|
| 577 |
+
# the word inform should be split as ['in', 'form']
|
| 578 |
+
tokenizer = CodeLlamaTokenizer.from_pretrained("codellama/CodeLlama-7b-hf", legacy=False)
|
| 579 |
+
tokens = tokenizer.tokenize("[INST] How are you doing?<s>[/INST]")
|
| 580 |
+
self.assertEqual(
|
| 581 |
+
tokens, ["▁[", "INST", "]", "▁How", "▁are", "▁you", "▁doing", "?", "<s>", "[", "/", "INST", "]"]
|
| 582 |
+
)
|
| 583 |
+
inputs_ids = tokenizer.encode("[INST] How are you doing?<s>[/INST]")
|
| 584 |
+
self.assertEqual(
|
| 585 |
+
inputs_ids, [1, 518, 25580, 29962, 1128, 526, 366, 2599, 29973, 1, 29961, 29914, 25580, 29962]
|
| 586 |
+
)
|
| 587 |
+
|
| 588 |
+
def test_infilling_tokenization(self):
|
| 589 |
+
PROMPTS = [
|
| 590 |
+
'''def remove_non_ascii(s: str) -> str:
|
| 591 |
+
""" <FILL_ME>
|
| 592 |
+
return result
|
| 593 |
+
''',
|
| 594 |
+
"""# Installation instructions:
|
| 595 |
+
```bash
|
| 596 |
+
<FILL_ME>
|
| 597 |
+
```
|
| 598 |
+
This downloads the LLaMA inference code and installs the repository as a local pip package.
|
| 599 |
+
""",
|
| 600 |
+
"""class InterfaceManagerFactory(AbstractManagerFactory):
|
| 601 |
+
def __init__(<FILL_ME>
|
| 602 |
+
def main():
|
| 603 |
+
factory = InterfaceManagerFactory(start=datetime.now())
|
| 604 |
+
managers = []
|
| 605 |
+
for i in range(10):
|
| 606 |
+
managers.append(factory.build(id=i))
|
| 607 |
+
""",
|
| 608 |
+
"""/-- A quasi-prefunctoid is 1-connected iff all its etalisations are 1-connected. -/
|
| 609 |
+
theorem connected_iff_etalisation [C D : precategoroid] (P : quasi_prefunctoid C D) :
|
| 610 |
+
π₁ P = 0 ↔ <FILL_ME> = 0 :=
|
| 611 |
+
begin
|
| 612 |
+
split,
|
| 613 |
+
{ intros h f,
|
| 614 |
+
rw pi_1_etalisation at h,
|
| 615 |
+
simp [h],
|
| 616 |
+
refl
|
| 617 |
+
},
|
| 618 |
+
{ intro h,
|
| 619 |
+
have := @quasi_adjoint C D P,
|
| 620 |
+
simp [←pi_1_etalisation, this, h],
|
| 621 |
+
refl
|
| 622 |
+
}
|
| 623 |
+
end
|
| 624 |
+
""",
|
| 625 |
+
]
|
| 626 |
+
tokenizer = CodeLlamaTokenizer.from_pretrained("codellama/CodeLlama-7b-Instruct-hf")
|
| 627 |
+
tokenizer_fast = CodeLlamaTokenizerFast.from_pretrained("codellama/CodeLlama-7b-Instruct-hf")
|
| 628 |
+
|
| 629 |
+
formatted_prompt = tokenizer.tokenize(PROMPTS[0])
|
| 630 |
+
self.assertEqual(formatted_prompt, tokenizer_fast.tokenize(PROMPTS[0]))
|
| 631 |
+
prefix, suffix = PROMPTS[0].split("<FILL_ME>")
|
| 632 |
+
self.assertEqual(formatted_prompt, tokenizer.tokenize(prefix, suffix))
|
| 633 |
+
self.assertEqual(formatted_prompt, tokenizer_fast.tokenize(prefix, suffix))
|
| 634 |
+
|
| 635 |
+
input_ids = tokenizer.encode(PROMPTS[0], add_special_tokens=False)
|
| 636 |
+
self.assertEqual(input_ids, tokenizer_fast.encode(PROMPTS[0], add_special_tokens=False))
|
| 637 |
+
|
| 638 |
+
prefix, suffix = PROMPTS[0].split("<FILL_ME>")
|
| 639 |
+
input_ids = tokenizer.encode(PROMPTS[0])
|
| 640 |
+
self.assertEqual(input_ids, tokenizer.encode(prefix, suffix=suffix))
|
| 641 |
+
self.assertEqual(tokenizer.encode(prefix, suffix=suffix), tokenizer_fast.encode(prefix, suffix=suffix))
|
| 642 |
+
|
| 643 |
+
# Adding suffix_first check for infilling tasks
|
| 644 |
+
suffix_first_formatted_prompt = tokenizer.tokenize(PROMPTS[0], suffix_first=True)
|
| 645 |
+
self.assertEqual(suffix_first_formatted_prompt, tokenizer_fast.tokenize(PROMPTS[0], suffix_first=True))
|
| 646 |
+
prefix, suffix = PROMPTS[0].split("<FILL_ME>")
|
| 647 |
+
self.assertEqual(suffix_first_formatted_prompt, tokenizer.tokenize(prefix, suffix, suffix_first=True))
|
| 648 |
+
self.assertEqual(suffix_first_formatted_prompt, tokenizer_fast.tokenize(prefix, suffix, suffix_first=True))
|
| 649 |
+
|
| 650 |
+
prefix, suffix = PROMPTS[0].split("<FILL_ME>")
|
| 651 |
+
suffix_first_input_ids = tokenizer.encode(PROMPTS[0], suffix_first=True)
|
| 652 |
+
self.assertEqual(suffix_first_input_ids, tokenizer.encode(prefix, suffix=suffix, suffix_first=True))
|
| 653 |
+
self.assertEqual(suffix_first_input_ids, tokenizer_fast.encode(prefix, suffix=suffix, suffix_first=True))
|
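Taken together, the infilling assertions above pin down one invariant worth restating: encoding a raw prompt that contains the fill sentinel must produce exactly the same ids as encoding the explicit (prefix, suffix) pair. A minimal sketch of that invariant, assuming network access to the same checkpoint the test uses (the prompt string here is hypothetical; only the <FILL_ME> sentinel is load-bearing):

from transformers import CodeLlamaTokenizer

tok = CodeLlamaTokenizer.from_pretrained("codellama/CodeLlama-7b-Instruct-hf")

# Hypothetical infilling prompt for illustration.
prompt = 'def add(a, b):\n    """ <FILL_ME>\n    return a + b\n'
prefix, suffix = prompt.split("<FILL_ME>")

# The raw prompt and the explicit (prefix, suffix) pair must agree.
assert tok.encode(prompt) == tok.encode(prefix, suffix=suffix)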
docs/transformers/tests/models/codegen/__init__.py
ADDED
File without changes

docs/transformers/tests/models/codegen/test_modeling_codegen.py
ADDED
@@ -0,0 +1,492 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import unittest

from transformers import CodeGenConfig, is_torch_available
from transformers.file_utils import cached_property
from transformers.testing_utils import backend_manual_seed, require_torch, slow, torch_device

from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin


if is_torch_available():
    import torch

    from transformers import AutoTokenizer, CodeGenForCausalLM, CodeGenModel


class CodeGenModelTester:
    def __init__(
        self,
        parent,
        batch_size=14,
        seq_length=7,
        is_training=True,
        use_token_type_ids=True,
        use_input_mask=True,
        use_labels=True,
        use_mc_token_ids=True,
        vocab_size=256,
        hidden_size=32,
        rotary_dim=4,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=37,
        hidden_act="gelu",
        hidden_dropout_prob=0.0,
        attention_probs_dropout_prob=0.0,
        max_position_embeddings=512,
        type_vocab_size=16,
        type_sequence_label_size=2,
        initializer_range=0.02,
        num_labels=3,
        num_choices=4,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_token_type_ids = use_token_type_ids
        self.use_input_mask = use_input_mask
        self.use_labels = use_labels
        self.use_mc_token_ids = use_mc_token_ids
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.rotary_dim = rotary_dim
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.type_sequence_label_size = type_sequence_label_size
        self.initializer_range = initializer_range
        self.num_labels = num_labels
        self.num_choices = num_choices
        self.scope = None
        self.bos_token_id = vocab_size - 1
        self.eos_token_id = vocab_size - 1
        self.pad_token_id = vocab_size - 1

    def get_large_model_config(self):
        return CodeGenConfig.from_pretrained("Salesforce/codegen-2B-mono")

    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
            input_mask = random_attention_mask([self.batch_size, self.seq_length])

        token_type_ids = None
        if self.use_token_type_ids:
            token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)

        mc_token_ids = None
        if self.use_mc_token_ids:
            mc_token_ids = ids_tensor([self.batch_size, self.num_choices], self.seq_length)

        sequence_labels = None
        token_labels = None
        choice_labels = None
        if self.use_labels:
            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
            choice_labels = ids_tensor([self.batch_size], self.num_choices)

        config = self.get_config()

        head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)

        return (
            config,
            input_ids,
            input_mask,
            head_mask,
            token_type_ids,
            mc_token_ids,
            sequence_labels,
            token_labels,
            choice_labels,
        )

    def get_config(self):
        return CodeGenConfig(
            vocab_size=self.vocab_size,
            n_embd=self.hidden_size,
            n_layer=self.num_hidden_layers,
            n_head=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            hidden_act=self.hidden_act,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
            n_positions=self.max_position_embeddings,
            type_vocab_size=self.type_vocab_size,
            initializer_range=self.initializer_range,
            use_cache=True,
            bos_token_id=self.bos_token_id,
            eos_token_id=self.eos_token_id,
            pad_token_id=self.pad_token_id,
            rotary_dim=self.rotary_dim,
        )

    def create_and_check_codegen_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
        model = CodeGenModel(config=config)
        model.to(torch_device)
        model.eval()

        result = model(input_ids, token_type_ids=token_type_ids, head_mask=head_mask)
        result = model(input_ids, token_type_ids=token_type_ids)
        result = model(input_ids)

        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
        self.parent.assertEqual(len(result.past_key_values), config.n_layer)

    def create_and_check_codegen_model_past(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
        model = CodeGenModel(config=config)
        model.to(torch_device)
        model.eval()

        # first forward pass
        outputs = model(input_ids, token_type_ids=token_type_ids, use_cache=True)
        outputs_use_cache_conf = model(input_ids, token_type_ids=token_type_ids)
        outputs_no_past = model(input_ids, token_type_ids=token_type_ids, use_cache=False)

        self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
        self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)

        output, past = outputs.to_tuple()

        # create a hypothetical next token and extend next_input_ids with it
        next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
        next_token_types = ids_tensor([self.batch_size, 1], self.type_vocab_size)

        # append to next input_ids and token_type_ids
        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
        next_token_type_ids = torch.cat([token_type_ids, next_token_types], dim=-1)

        output_from_no_past = model(next_input_ids, token_type_ids=next_token_type_ids)["last_hidden_state"]
        output_from_past = model(next_tokens, token_type_ids=next_token_types, past_key_values=past)[
            "last_hidden_state"
        ]

        # select random slice
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
        output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach()
        output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()

        # test that outputs are equal for slice
        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))

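    # The "past" checks above and below pin down the KV-cache contract: feeding
    # only the new tokens together with `past_key_values` must reproduce, up to
    # float tolerance, the hidden states obtained by re-running the full
    # concatenated sequence. A self-contained sketch of that pattern (the tiny
    # random config here is an assumption chosen for speed, not a real checkpoint):
    #
    #     import torch
    #     from transformers import CodeGenConfig, CodeGenModel
    #
    #     model = CodeGenModel(CodeGenConfig(vocab_size=64, n_embd=32, n_layer=2, n_head=4, rotary_dim=4)).eval()
    #     ids = torch.randint(0, 64, (1, 7))
    #     nxt = torch.randint(0, 64, (1, 1))
    #
    #     past = model(ids, use_cache=True).past_key_values
    #     full = model(torch.cat([ids, nxt], dim=-1)).last_hidden_state[:, -1]
    #     step = model(nxt, past_key_values=past).last_hidden_state[:, 0]
    #     assert torch.allclose(full, step, atol=1e-3)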
    def create_and_check_codegen_model_attention_mask_past(
        self, config, input_ids, input_mask, head_mask, token_type_ids, *args
    ):
        model = CodeGenModel(config=config)
        model.to(torch_device)
        model.eval()

        # create attention mask
        attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
        half_seq_length = self.seq_length // 2
        attn_mask[:, half_seq_length:] = 0

        # first forward pass
        output, past = model(input_ids, attention_mask=attn_mask).to_tuple()

        # create a hypothetical next token and extend next_input_ids with it
        next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)

        # change a random masked slice from input_ids
        random_seq_idx_to_change = ids_tensor((1,), half_seq_length).item() + 1
        random_other_next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size).squeeze(-1)
        input_ids[:, -random_seq_idx_to_change] = random_other_next_tokens

        # append to next input_ids and attn_mask
        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
        attn_mask = torch.cat(
            [attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)],
            dim=1,
        )

        # get two different outputs
        output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"]
        output_from_past = model(next_tokens, past_key_values=past, attention_mask=attn_mask)["last_hidden_state"]

        # select random slice
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
        output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach()
        output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()

        # test that outputs are equal for slice
        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))

    def create_and_check_codegen_model_past_large_inputs(
        self, config, input_ids, input_mask, head_mask, token_type_ids, *args
    ):
        model = CodeGenModel(config=config)
        model.to(torch_device)
        model.eval()

        # first forward pass
        outputs = model(input_ids, token_type_ids=token_type_ids, attention_mask=input_mask, use_cache=True)

        output, past = outputs.to_tuple()

        # create hypothetical next tokens and extend next_input_ids with them
        next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
        next_token_types = ids_tensor([self.batch_size, 3], self.type_vocab_size)
        next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)

        # append to next input_ids and token_type_ids
        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
        next_token_type_ids = torch.cat([token_type_ids, next_token_types], dim=-1)
        next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)

        output_from_no_past = model(
            next_input_ids, token_type_ids=next_token_type_ids, attention_mask=next_attention_mask
        )["last_hidden_state"]
        output_from_past = model(
            next_tokens, token_type_ids=next_token_types, attention_mask=next_attention_mask, past_key_values=past
        )["last_hidden_state"]
        self.parent.assertTrue(output_from_past.shape[1] == next_tokens.shape[1])

        # select random slice
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
        output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
        output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()

        # test that outputs are equal for slice
        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))

    def create_and_check_lm_head_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
        model = CodeGenForCausalLM(config)
        model.to(torch_device)
        model.eval()

        result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids)
        self.parent.assertEqual(result.loss.shape, ())
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

    def create_and_check_forward_and_backwards(
        self, config, input_ids, input_mask, head_mask, token_type_ids, *args, gradient_checkpointing=False
    ):
        model = CodeGenForCausalLM(config)
        if gradient_checkpointing:
            model.gradient_checkpointing_enable()
        model.to(torch_device)

        result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids)
        self.parent.assertEqual(result.loss.shape, ())
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
        result.loss.backward()

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()

        (
            config,
            input_ids,
            input_mask,
            head_mask,
            token_type_ids,
            mc_token_ids,
            sequence_labels,
            token_labels,
            choice_labels,
        ) = config_and_inputs

        inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "head_mask": head_mask}

        return config, inputs_dict


@require_torch
class CodeGenModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
    all_model_classes = (CodeGenModel, CodeGenForCausalLM) if is_torch_available() else ()
    pipeline_model_mapping = (
        {"feature-extraction": CodeGenModel, "text-generation": CodeGenForCausalLM} if is_torch_available() else {}
    )
    fx_compatible = False
    test_pruning = False
    test_missing_keys = False
    test_model_parallel = False
    test_head_masking = False

    # special case for DoubleHeads model
    def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
        inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
        return inputs_dict

    def setUp(self):
        self.model_tester = CodeGenModelTester(self)
        self.config_tester = ConfigTester(self, config_class=CodeGenConfig, n_embd=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_codegen_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_codegen_model(*config_and_inputs)

    def test_codegen_model_past(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_codegen_model_past(*config_and_inputs)

    def test_codegen_model_att_mask_past(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_codegen_model_attention_mask_past(*config_and_inputs)

    def test_codegen_model_past_large_inputs(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_codegen_model_past_large_inputs(*config_and_inputs)

    def test_codegen_lm_head_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_lm_head_model(*config_and_inputs)

    def test_codegen_gradient_checkpointing(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs, gradient_checkpointing=True)

    @slow
    def test_batch_generation(self):
        tokenizer = AutoTokenizer.from_pretrained("Salesforce/codegen-350M-mono")
        model = CodeGenForCausalLM.from_pretrained("Salesforce/codegen-350M-mono")
        model.to(torch_device)

        tokenizer.padding_side = "left"

        # Define PAD Token = EOS Token = 50256
        tokenizer.pad_token = tokenizer.eos_token
        model.config.pad_token_id = model.config.eos_token_id

        # use sentences of different lengths to test batching
        sentences = ["def hellow_world():", "def greet(name):"]

        inputs = tokenizer(sentences, return_tensors="pt", padding=True)
        input_ids = inputs["input_ids"].to(torch_device)
        token_type_ids = torch.cat(
            [
                input_ids.new_full((input_ids.shape[0], input_ids.shape[1] - 1), 0),
                input_ids.new_full((input_ids.shape[0], 1), 500),
            ],
            dim=-1,
        )

        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=inputs["attention_mask"].to(torch_device),
        )

        outputs_tt = model.generate(
            input_ids=input_ids,
            attention_mask=inputs["attention_mask"].to(torch_device),
            token_type_ids=token_type_ids,
        )

        inputs_non_padded = tokenizer(sentences[0], return_tensors="pt").input_ids.to(torch_device)
        output_non_padded = model.generate(input_ids=inputs_non_padded)

        num_paddings = inputs_non_padded.shape[-1] - inputs["attention_mask"][-1].long().sum().item()
        inputs_padded = tokenizer(sentences[1], return_tensors="pt").input_ids.to(torch_device)
        output_padded = model.generate(input_ids=inputs_padded, max_length=model.config.max_length - num_paddings)

        batch_out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        batch_out_sentence_tt = tokenizer.batch_decode(outputs_tt, skip_special_tokens=True)
        non_padded_sentence = tokenizer.decode(output_non_padded[0], skip_special_tokens=True)
        padded_sentence = tokenizer.decode(output_padded[0], skip_special_tokens=True)

        expected_output_sentence = [
            'def hellow_world():\n    print("Hello World")\n\nhellow_world()',
            'def greet(name):\n    print(f"Hello {name}")\n\ng',
        ]
        self.assertListEqual(expected_output_sentence, batch_out_sentence)
        self.assertTrue(batch_out_sentence_tt != batch_out_sentence)  # token_type_ids should change output
        self.assertListEqual(expected_output_sentence, [non_padded_sentence, padded_sentence])

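    # Note on the padding setup above: decoder-only models like CodeGen must be
    # padded on the left for batched generation, otherwise pad tokens end up
    # between the prompt and the freshly generated ids. The recipe this test
    # pins down is:
    #
    #     tokenizer.padding_side = "left"
    #     tokenizer.pad_token = tokenizer.eos_token          # CodeGen ships no pad token
    #     model.config.pad_token_id = model.config.eos_token_id
    #
    # after which the batched generations must match the per-sentence ones once
    # decoded with skip_special_tokens=True.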
    @slow
    def test_model_from_pretrained(self):
        model_name = "Salesforce/codegen-350M-nl"
        model = CodeGenModel.from_pretrained(model_name)
        self.assertIsNotNone(model)


@require_torch
class CodeGenModelLanguageGenerationTest(unittest.TestCase):
    @cached_property
    def cached_tokenizer(self):
        return AutoTokenizer.from_pretrained("Salesforce/codegen-350M-mono")

    @cached_property
    def cached_model(self):
        return CodeGenForCausalLM.from_pretrained("Salesforce/codegen-350M-mono")

    @slow
    def test_lm_generate_codegen(self):
        tokenizer = self.cached_tokenizer
        for checkpointing in [True, False]:
            model = self.cached_model

            if checkpointing:
                model.gradient_checkpointing_enable()
            else:
                model.gradient_checkpointing_disable()
            model.to(torch_device)

            inputs = tokenizer("def hello_world():", return_tensors="pt").to(torch_device)
            expected_output = 'def hello_world():\n    print("Hello World")\n\nhello_world()\n\n'

            output_ids = model.generate(**inputs, do_sample=False)
            output_str = tokenizer.batch_decode(output_ids)[0]

            self.assertEqual(output_str, expected_output)

    @slow
    def test_codegen_sample(self):
        tokenizer = self.cached_tokenizer
        model = self.cached_model
        model.to(torch_device)

        torch.manual_seed(0)
        backend_manual_seed(torch_device, 0)

        tokenized = tokenizer("def hello_world():", return_tensors="pt", return_token_type_ids=True)
        input_ids = tokenized.input_ids.to(torch_device)
        output_ids = model.generate(input_ids, do_sample=True)
        output_str = tokenizer.decode(output_ids[0], skip_special_tokens=True)

        token_type_ids = tokenized.token_type_ids.to(torch_device)
        output_seq = model.generate(input_ids=input_ids, do_sample=True, num_return_sequences=5)
        output_seq_tt = model.generate(
            input_ids=input_ids, token_type_ids=token_type_ids, do_sample=True, num_return_sequences=5
        )
        output_seq_strs = tokenizer.batch_decode(output_seq, skip_special_tokens=True)
        output_seq_tt_strs = tokenizer.batch_decode(output_seq_tt, skip_special_tokens=True)

        if torch_device == "cuda":
            EXPECTED_OUTPUT_STR = 'def hello_world():\n    print("Hello World")\n    return True\n\nresult ='
        else:
            EXPECTED_OUTPUT_STR = "def hello_world():\r\n    print('Hello, World.')\r\n\r\n\r"

        self.assertEqual(output_str, EXPECTED_OUTPUT_STR)
        self.assertTrue(
            all(output_seq_strs[idx] != output_seq_tt_strs[idx] for idx in range(len(output_seq_tt_strs)))
        )  # token_type_ids should change output
docs/transformers/tests/models/codegen/test_tokenization_codegen.py
ADDED
@@ -0,0 +1,329 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import json
import os
import re
import unittest
from functools import lru_cache

from transformers import CodeGenTokenizer, CodeGenTokenizerFast
from transformers.models.codegen.tokenization_codegen import VOCAB_FILES_NAMES
from transformers.testing_utils import require_tokenizers, slow

from ...test_tokenization_common import TokenizerTesterMixin, use_cache_if_possible


@require_tokenizers
class CodeGenTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
    from_pretrained_id = "Salesforce/codegen-350M-mono"
    tokenizer_class = CodeGenTokenizer
    rust_tokenizer_class = CodeGenTokenizerFast
    test_rust_tokenizer = True
    from_pretrained_kwargs = {"add_prefix_space": True}
    test_seq2seq = False

    @classmethod
    def setUpClass(cls):
        super().setUpClass()

        # Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
        vocab = [
            "l",
            "o",
            "w",
            "e",
            "r",
            "s",
            "t",
            "i",
            "d",
            "n",
            "\u0120",
            "\u0120l",
            "\u0120n",
            "\u0120lo",
            "\u0120low",
            "er",
            "\u0120lowest",
            "\u0120newer",
            "\u0120wider",
            "<unk>",
            "<|endoftext|>",
        ]
        vocab_tokens = dict(zip(vocab, range(len(vocab))))
        merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""]
        cls.special_tokens_map = {"unk_token": "<unk>"}

        cls.vocab_file = os.path.join(cls.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
        cls.merges_file = os.path.join(cls.tmpdirname, VOCAB_FILES_NAMES["merges_file"])
        with open(cls.vocab_file, "w", encoding="utf-8") as fp:
            fp.write(json.dumps(vocab_tokens) + "\n")
        with open(cls.merges_file, "w", encoding="utf-8") as fp:
            fp.write("\n".join(merges))

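    # A worked example of the toy vocabulary above, useful when reading
    # test_full_tokenizer below. With merge rules ("\u0120 l", "\u0120l o",
    # "\u0120lo w", "e r"), the byte-level form of " lower newer" reduces as:
    #
    #     \u0120 l o w e r  ->  \u0120l o w e r  ->  \u0120lo w e r  ->  \u0120low e r  ->  \u0120low er
    #     \u0120 n e w e r  ->  \u0120 n e w er      (no merge rule touches "\u0120 n")
    #
    # i.e. exactly the expected ["\u0120low", "er", "\u0120", "n", "e", "w", "er"].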
    @classmethod
    @use_cache_if_possible
    @lru_cache(maxsize=64)
    def get_tokenizer(cls, pretrained_name=None, **kwargs):
        kwargs.update(cls.special_tokens_map)
        pretrained_name = pretrained_name or cls.tmpdirname
        return CodeGenTokenizer.from_pretrained(pretrained_name, **kwargs)

    @classmethod
    @use_cache_if_possible
    @lru_cache(maxsize=64)
    def get_rust_tokenizer(cls, pretrained_name=None, **kwargs):
        kwargs.update(cls.special_tokens_map)
        pretrained_name = pretrained_name or cls.tmpdirname
        return CodeGenTokenizerFast.from_pretrained(pretrained_name, **kwargs)

    def get_input_output_texts(self, tokenizer):
        input_text = "lower newer"
        output_text = "lower newer"
        return input_text, output_text

    def test_full_tokenizer(self):
        tokenizer = CodeGenTokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map)
        text = "lower newer"
        bpe_tokens = ["\u0120low", "er", "\u0120", "n", "e", "w", "er"]
        tokens = tokenizer.tokenize(text, add_prefix_space=True)
        self.assertListEqual(tokens, bpe_tokens)

        input_tokens = tokens + [tokenizer.unk_token]
        input_bpe_tokens = [14, 15, 10, 9, 3, 2, 15, 19]
        self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)

    def test_rust_and_python_full_tokenizers(self):
        if not self.test_rust_tokenizer:
            self.skipTest(reason="test_rust_tokenizer is set to False")

        tokenizer = self.get_tokenizer()
        rust_tokenizer = self.get_rust_tokenizer(add_prefix_space=True)

        sequence = "lower newer"

        # Testing tokenization
        tokens = tokenizer.tokenize(sequence, add_prefix_space=True)
        rust_tokens = rust_tokenizer.tokenize(sequence)
        self.assertListEqual(tokens, rust_tokens)

        # Testing conversion to ids without special tokens
        ids = tokenizer.encode(sequence, add_special_tokens=False, add_prefix_space=True)
        rust_ids = rust_tokenizer.encode(sequence, add_special_tokens=False)
        self.assertListEqual(ids, rust_ids)

        # Testing conversion to ids with special tokens
        rust_tokenizer = self.get_rust_tokenizer(add_prefix_space=True)
        ids = tokenizer.encode(sequence, add_prefix_space=True)
        rust_ids = rust_tokenizer.encode(sequence)
        self.assertListEqual(ids, rust_ids)

        # Testing the unknown token
        input_tokens = tokens + [rust_tokenizer.unk_token]
        input_bpe_tokens = [14, 15, 10, 9, 3, 2, 15, 19]
        self.assertListEqual(rust_tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)

    @unittest.skip
    def test_pretokenized_inputs(self, *args, **kwargs):
        # It's very difficult to mix/test pretokenization with byte-level
        # and get both CodeGen and Roberta to work at the same time (mostly an issue of adding a space before the string)
        pass

    def test_padding(self, max_length=15):
        for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
            with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
                tokenizer_r = self.get_rust_tokenizer(pretrained_name, **kwargs)

                # Simple input
                s = "This is a simple input"
                s2 = ["This is a simple input 1", "This is a simple input 2"]
                p = ("This is a simple input", "This is a pair")
                p2 = [
                    ("This is a simple input 1", "This is a simple input 2"),
                    ("This is a simple pair 1", "This is a simple pair 2"),
                ]

                # Simple input tests
                self.assertRaises(ValueError, tokenizer_r.encode, s, max_length=max_length, padding="max_length")

                # Simple input
                self.assertRaises(ValueError, tokenizer_r.encode_plus, s, max_length=max_length, padding="max_length")

                # Simple input
                self.assertRaises(
                    ValueError,
                    tokenizer_r.batch_encode_plus,
                    s2,
                    max_length=max_length,
                    padding="max_length",
                )

                # Pair input
                self.assertRaises(ValueError, tokenizer_r.encode, p, max_length=max_length, padding="max_length")

                # Pair input
                self.assertRaises(ValueError, tokenizer_r.encode_plus, p, max_length=max_length, padding="max_length")

                # Pair input
                self.assertRaises(
                    ValueError,
                    tokenizer_r.batch_encode_plus,
                    p2,
                    max_length=max_length,
                    padding="max_length",
                )

    def test_padding_if_pad_token_set_slow(self):
        tokenizer = CodeGenTokenizer.from_pretrained(self.tmpdirname, pad_token="<pad>")

        # Simple input
        s = "This is a simple input"
        s2 = ["This is a simple input looooooooong", "This is a simple input"]
        p = ("This is a simple input", "This is a pair")
        p2 = [
            ("This is a simple input loooooong", "This is a simple input"),
            ("This is a simple pair loooooong", "This is a simple pair"),
        ]

        pad_token_id = tokenizer.pad_token_id

        out_s = tokenizer(s, padding="max_length", max_length=30, return_tensors="np")
        out_s2 = tokenizer(s2, padding=True, truncate=True, return_tensors="np")
        out_p = tokenizer(*p, padding="max_length", max_length=60, return_tensors="np")
        out_p2 = tokenizer(p2, padding=True, truncate=True, return_tensors="np")

        # s
        # test single string max_length padding
        self.assertEqual(out_s["input_ids"].shape[-1], 30)
        self.assertTrue(pad_token_id in out_s["input_ids"])
        self.assertTrue(0 in out_s["attention_mask"])

        # s2
        # test automatic padding
        self.assertEqual(out_s2["input_ids"].shape[-1], 33)
        # long slice doesn't have padding
        self.assertFalse(pad_token_id in out_s2["input_ids"][0])
        self.assertFalse(0 in out_s2["attention_mask"][0])
        # short slice does have padding
        self.assertTrue(pad_token_id in out_s2["input_ids"][1])
        self.assertTrue(0 in out_s2["attention_mask"][1])

        # p
        # test single pair max_length padding
        self.assertEqual(out_p["input_ids"].shape[-1], 60)
        self.assertTrue(pad_token_id in out_p["input_ids"])
        self.assertTrue(0 in out_p["attention_mask"])

        # p2
        # test automatic padding pair
        self.assertEqual(out_p2["input_ids"].shape[-1], 52)
        # long slice pair doesn't have padding
        self.assertFalse(pad_token_id in out_p2["input_ids"][0])
        self.assertFalse(0 in out_p2["attention_mask"][0])
        # short slice pair does have padding
        self.assertTrue(pad_token_id in out_p2["input_ids"][1])
        self.assertTrue(0 in out_p2["attention_mask"][1])

    def test_add_bos_token_slow(self):
        bos_token = "$$$"
        tokenizer = CodeGenTokenizer.from_pretrained(self.tmpdirname, bos_token=bos_token, add_bos_token=True)

        s = "This is a simple input"
        s2 = ["This is a simple input 1", "This is a simple input 2"]

        bos_token_id = tokenizer.bos_token_id

        out_s = tokenizer(s)
        out_s2 = tokenizer(s2)

        self.assertEqual(out_s.input_ids[0], bos_token_id)
        self.assertTrue(all(o[0] == bos_token_id for o in out_s2.input_ids))

        decode_s = tokenizer.decode(out_s.input_ids)
        decode_s2 = tokenizer.batch_decode(out_s2.input_ids)

        self.assertTrue(decode_s.startswith(bos_token))
        self.assertTrue(all(d.startswith(bos_token) for d in decode_s2))

    @slow
    def test_truncation(self):
        tokenizer = CodeGenTokenizer.from_pretrained("Salesforce/codegen-350M-mono")

        text = "\nif len_a > len_b:\n    result = a\nelse:\n    result = b\n\n\n\n#"
        expected_truncated_text = "\nif len_a > len_b:\n    result = a\nelse:\n    result = b"

        input_ids = tokenizer.encode(text)
        truncation_pattern = ["^#", re.escape("<|endoftext|>"), "^'''", '^"""', "\n\n\n"]
        decoded_text = tokenizer.decode(input_ids, truncate_before_pattern=truncation_pattern)
        self.assertEqual(decoded_text, expected_truncated_text)
        # TODO @ArthurZ outputs of the fast tokenizer are different in this case, un-related to the PR

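    # The truncation above happens at decode time and is regex-driven: decoding
    # stops right before the earliest match of any entry in
    # `truncate_before_pattern`. A minimal sketch, assuming the same checkpoint
    # as the test (the input string here is hypothetical):
    #
    #     tok = CodeGenTokenizer.from_pretrained("Salesforce/codegen-350M-mono")
    #     ids = tok.encode("result = a\n\n\n\n# trailing comment")
    #     tok.decode(ids, truncate_before_pattern=["\n\n\n"])  # -> "result = a"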
    # tokenizer has no padding token
    @unittest.skip(reason="tokenizer has no padding token")
    def test_padding_different_model_input_name(self):
        pass

    @slow
    def test_tokenizer_integration(self):
        # Custom test since this tokenizer takes return_token_type_ids as an init argument for backward compatibility.

        sequences = [
            "Transformers (formerly known as pytorch-transformers and pytorch-pretrained-bert) provides "
            "general-purpose architectures (BERT, GPT-2, RoBERTa, XLM, DistilBert, XLNet...) for Natural "
            "Language Understanding (NLU) and Natural Language Generation (NLG) with over 32+ pretrained "
            "models in 100+ languages and deep interoperability between Jax, PyTorch and TensorFlow.",
            "BERT is designed to pre-train deep bidirectional representations from unlabeled text by jointly "
            "conditioning on both left and right context in all layers.",
            "The quick brown fox jumps over the lazy dog.",
        ]

        tokenizer_classes = [self.tokenizer_class]
        if self.test_rust_tokenizer:
            tokenizer_classes.append(self.rust_tokenizer_class)

        # Test the default case, i.e. return_token_type_ids is False.
        for tokenizer_class in tokenizer_classes:
            tokenizer = tokenizer_class.from_pretrained("Salesforce/codegen-350M-mono")

            encoding = tokenizer(sequences)
            decoded_sequences = [tokenizer.decode(seq, skip_special_tokens=True) for seq in encoding["input_ids"]]

            # fmt: off
            expected_encoding = {'input_ids': [[41762, 364, 357, 36234, 1900, 355, 12972, 13165, 354, 12, 35636, 364, 290, 12972, 13165, 354, 12, 5310, 13363, 12, 4835, 8, 3769, 2276, 12, 29983, 45619, 357, 13246, 51, 11, 402, 11571, 12, 17, 11, 5564, 13246, 38586, 11, 16276, 44, 11, 4307, 346, 33, 861, 11, 16276, 7934, 23029, 329, 12068, 15417, 28491, 357, 32572, 52, 8, 290, 12068, 15417, 16588, 357, 32572, 38, 8, 351, 625, 3933, 10, 2181, 13363, 4981, 287, 1802, 10, 8950, 290, 2769, 48817, 1799, 1022, 449, 897, 11, 9485, 15884, 354, 290, 309, 22854, 37535, 13], [13246, 51, 318, 3562, 284, 662, 12, 27432, 2769, 8406, 4154, 282, 24612, 422, 9642, 9608, 276, 2420, 416, 26913, 21143, 319, 1111, 1364, 290, 826, 4732, 287, 477, 11685, 13], [464, 2068, 7586, 21831, 18045, 625, 262, 16931, 3290, 13]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}  # noqa: E501
            # fmt: on

            encoding_data = encoding.data
            self.assertDictEqual(encoding_data, expected_encoding)

            for expected, decoded in zip(sequences, decoded_sequences):
                self.assertEqual(expected, decoded)

        # Test the return_token_type_ids=True case.
        for tokenizer_class in tokenizer_classes:
            tokenizer = tokenizer_class.from_pretrained("Salesforce/codegen-350M-mono", return_token_type_ids=True)

            encoding = tokenizer(sequences)
            decoded_sequences = [tokenizer.decode(seq, skip_special_tokens=True) for seq in encoding["input_ids"]]

            # fmt: off
            expected_encoding = {'input_ids': [[41762, 364, 357, 36234, 1900, 355, 12972, 13165, 354, 12, 35636, 364, 290, 12972, 13165, 354, 12, 5310, 13363, 12, 4835, 8, 3769, 2276, 12, 29983, 45619, 357, 13246, 51, 11, 402, 11571, 12, 17, 11, 5564, 13246, 38586, 11, 16276, 44, 11, 4307, 346, 33, 861, 11, 16276, 7934, 23029, 329, 12068, 15417, 28491, 357, 32572, 52, 8, 290, 12068, 15417, 16588, 357, 32572, 38, 8, 351, 625, 3933, 10, 2181, 13363, 4981, 287, 1802, 10, 8950, 290, 2769, 48817, 1799, 1022, 449, 897, 11, 9485, 15884, 354, 290, 309, 22854, 37535, 13], [13246, 51, 318, 3562, 284, 662, 12, 27432, 2769, 8406, 4154, 282, 24612, 422, 9642, 9608, 276, 2420, 416, 26913, 21143, 319, 1111, 1364, 290, 826, 4732, 287, 477, 11685, 13], [464, 2068, 7586, 21831, 18045, 625, 262, 16931, 3290, 13]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}  # noqa: E501
            # fmt: on

            encoding_data = encoding.data
            self.assertDictEqual(encoding_data, expected_encoding)

            for expected, decoded in zip(sequences, decoded_sequences):
                self.assertEqual(expected, decoded)
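One behavior worth a standalone illustration: `return_token_type_ids` is accepted at construction time for backward compatibility, so the flag does not have to be repeated on every call. A minimal sketch against the same checkpoint the integration test uses (the exact-key assertion reflects what that test expects):

from transformers import CodeGenTokenizer

tok = CodeGenTokenizer.from_pretrained("Salesforce/codegen-350M-mono", return_token_type_ids=True)
enc = tok("The quick brown fox jumps over the lazy dog.")

# token_type_ids now comes back alongside the usual keys.
assert set(enc.keys()) == {"input_ids", "token_type_ids", "attention_mask"}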
docs/transformers/tests/models/cohere/__init__.py
ADDED
File without changes