Student0809 commited on
Commit
cb2428f
·
verified ·
1 Parent(s): b381930

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. docs/transformers/tests/trainer/test_trainer_fsdp.py +175 -0
  2. docs/transformers/tests/utils/import_structures/failing_export.py +23 -0
  3. docs/transformers/tests/utils/import_structures/import_structure_raw_register.py +80 -0
  4. docs/transformers/tests/utils/import_structures/import_structure_register_with_comments.py +79 -0
  5. docs/transformers/tests/utils/import_structures/import_structure_register_with_duplicates.py +77 -0
  6. docs/transformers/tests/utils/test_activations_tf.py +60 -0
  7. docs/transformers/tests/utils/test_add_new_model_like.py +1506 -0
  8. docs/transformers/tests/utils/test_audio_utils.py +1751 -0
  9. docs/transformers/tests/utils/test_backbone_utils.py +272 -0
  10. docs/transformers/tests/utils/test_cache_utils.py +766 -0
  11. docs/transformers/tests/utils/test_chat_template_utils.py +501 -0
  12. docs/transformers/tests/utils/test_cli.py +77 -0
  13. docs/transformers/tests/utils/test_configuration_utils.py +302 -0
  14. docs/transformers/tests/utils/test_convert_slow_tokenizer.py +35 -0
  15. docs/transformers/tests/utils/test_deprecation.py +195 -0
  16. docs/transformers/tests/utils/test_doc_samples.py +112 -0
  17. docs/transformers/tests/utils/test_dynamic_module_utils.py +129 -0
  18. docs/transformers/tests/utils/test_expectations.py +34 -0
  19. docs/transformers/tests/utils/test_feature_extraction_utils.py +121 -0
  20. docs/transformers/tests/utils/test_file_utils.py +133 -0
  21. docs/transformers/tests/utils/test_generic.py +463 -0
  22. docs/transformers/tests/utils/test_hf_argparser.py +482 -0
  23. docs/transformers/tests/utils/test_hub_utils.py +200 -0
  24. docs/transformers/tests/utils/test_image_processing_utils.py +181 -0
  25. docs/transformers/tests/utils/test_image_utils.py +1061 -0
  26. docs/transformers/tests/utils/test_import_structure.py +104 -0
  27. docs/transformers/tests/utils/test_import_utils.py +26 -0
  28. docs/transformers/tests/utils/test_logging.py +135 -0
  29. docs/transformers/tests/utils/test_model_card.py +88 -0
  30. docs/transformers/tests/utils/test_model_debugging_utils.py +122 -0
  31. docs/transformers/tests/utils/test_model_output.py +201 -0
  32. docs/transformers/tests/utils/test_modeling_flax_utils.py +285 -0
  33. docs/transformers/tests/utils/test_modeling_rope_utils.py +453 -0
  34. docs/transformers/tests/utils/test_modeling_tf_core.py +403 -0
  35. docs/transformers/tests/utils/test_modeling_tf_utils.py +662 -0
  36. docs/transformers/tests/utils/test_modeling_utils.py +0 -0
  37. docs/transformers/tests/utils/test_offline.py +220 -0
  38. docs/transformers/tests/utils/test_processing_utils.py +175 -0
  39. docs/transformers/tests/utils/test_skip_decorators.py +119 -0
  40. docs/transformers/tests/utils/test_tokenization_utils.py +303 -0
  41. docs/transformers/tests/utils/test_versions_utils.py +97 -0
  42. docs/transformers/tests/utils/tiny_model_summary.json +0 -0
  43. docs/transformers/utils/add_pipeline_model_mapping_to_test.py +336 -0
  44. docs/transformers/utils/check_bad_commit.py +202 -0
  45. docs/transformers/utils/check_build.py +49 -0
  46. docs/transformers/utils/check_config_attributes.py +470 -0
  47. docs/transformers/utils/check_config_docstrings.py +102 -0
  48. docs/transformers/utils/check_copies.py +1078 -0
  49. docs/transformers/utils/check_doc_toc.py +134 -0
  50. docs/transformers/utils/check_docstrings.py +1061 -0
docs/transformers/tests/trainer/test_trainer_fsdp.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from transformers import is_torch_available
17
+ from transformers.testing_utils import (
18
+ TestCasePlus,
19
+ backend_device_count,
20
+ execute_subprocess_async,
21
+ get_torch_dist_unique_port,
22
+ require_accelerate,
23
+ require_fp8,
24
+ require_torch_multi_accelerator,
25
+ run_first,
26
+ torch_device,
27
+ )
28
+
29
+
30
+ if is_torch_available():
31
+ import torch
32
+ import torch.distributed
33
+ import torch.utils.data
34
+
35
+ from transformers import (
36
+ AutoModelForCausalLM,
37
+ AutoTokenizer,
38
+ DataCollatorForSeq2Seq,
39
+ EvalPrediction,
40
+ GenerationConfig,
41
+ HfArgumentParser,
42
+ PreTrainedTokenizerBase,
43
+ Seq2SeqTrainer,
44
+ Seq2SeqTrainingArguments,
45
+ )
46
+
47
+ class DummyTextDataset(torch.utils.data.Dataset[str]):
48
+ def __init__(self, tokenizer: PreTrainedTokenizerBase) -> None:
49
+ data = 4 * [
50
+ "Hello world!",
51
+ "The quick brown fox jumps over the lazy dog.",
52
+ ]
53
+ self.data = [
54
+ {k: v.squeeze(0) for k, v in tokenizer(item, return_tensors="pt", return_attention_mask=True).items()}
55
+ for item in data
56
+ ]
57
+ for item in self.data:
58
+ item["labels"] = item["input_ids"]
59
+
60
+ def __len__(self) -> int:
61
+ return len(self.data)
62
+
63
+ def __getitem__(self, i: int) -> str:
64
+ return self.data[i]
65
+
66
+
67
+ class TestFSDPTrainer(TestCasePlus):
68
+ @require_torch_multi_accelerator
69
+ @require_accelerate
70
+ @run_first
71
+ def test_trainer(self):
72
+ output_dir = self.get_auto_remove_tmp_dir()
73
+ cmd = [
74
+ "accelerate",
75
+ "launch",
76
+ "--use_fsdp",
77
+ "--main_process_port",
78
+ f"{get_torch_dist_unique_port()}",
79
+ "--num_processes",
80
+ f"{backend_device_count(torch_device)}",
81
+ "--fsdp_transformer_layer_cls_to_wrap",
82
+ "GPT2Block",
83
+ f"{self.test_file_dir}/test_trainer_fsdp.py",
84
+ "--output_dir",
85
+ f"{output_dir}",
86
+ "--report_to",
87
+ "none",
88
+ ]
89
+ execute_subprocess_async(cmd, env=self.get_env())
90
+ # successful return here == success - any errors would have caused an error in the sub-call
91
+
92
+
93
+ class TestFSDPTrainerFP8(TestCasePlus):
94
+ @require_torch_multi_accelerator
95
+ @require_accelerate
96
+ @require_fp8
97
+ @run_first
98
+ def test_trainer(self):
99
+ output_dir = self.get_auto_remove_tmp_dir()
100
+ cmd = [
101
+ "accelerate",
102
+ "launch",
103
+ "--use_fsdp",
104
+ "--main_process_port",
105
+ f"{get_torch_dist_unique_port()}",
106
+ "--num_processes",
107
+ f"{backend_device_count(torch_device)}",
108
+ "--mixed_precision",
109
+ "fp8",
110
+ "--fsdp_transformer_layer_cls_to_wrap",
111
+ "GPT2Block",
112
+ f"{self.test_file_dir}/test_trainer_fsdp.py",
113
+ "--output_dir",
114
+ f"{output_dir}",
115
+ "--report_to",
116
+ "none",
117
+ ]
118
+ execute_subprocess_async(cmd, env=self.get_env())
119
+ # successful return here == success - any errors would have caused an error in the sub-call
120
+
121
+
122
+ class TestFSDPTrainerWrap(TestCasePlus):
123
+ @require_torch_multi_accelerator
124
+ @require_accelerate
125
+ @run_first
126
+ def test_trainer(self):
127
+ output_dir = self.get_auto_remove_tmp_dir()
128
+ cmd = [
129
+ "accelerate",
130
+ "launch",
131
+ "--use_fsdp",
132
+ "--main_process_port",
133
+ f"{get_torch_dist_unique_port()}",
134
+ "--num_processes",
135
+ f"{backend_device_count(torch_device)}",
136
+ "--fsdp_transformer_layer_cls_to_wrap",
137
+ "GPT2Block",
138
+ f"{self.test_file_dir}/test_trainer_fsdp.py",
139
+ "--output_dir",
140
+ f"{output_dir}",
141
+ "--report_to",
142
+ "none",
143
+ "--auto_find_batch_size",
144
+ "True",
145
+ ]
146
+ execute_subprocess_async(cmd, env=self.get_env())
147
+ # successful return here == success - any errors would have caused an error in the sub-call
148
+
149
+
150
+ if __name__ == "__main__":
151
+ parser = HfArgumentParser((Seq2SeqTrainingArguments,))
152
+ training_args = parser.parse_args_into_dataclasses()[0]
153
+ training_args.per_device_eval_batch_size = 1
154
+ training_args.use_legacy_prediction_loop = False
155
+ training_args.predict_with_generate = True
156
+ training_args.generation_config = GenerationConfig(max_length=30)
157
+
158
+ pretrained_model_name = "hf-internal-testing/tiny-random-gpt2"
159
+ tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name)
160
+ tokenizer.pad_token = tokenizer.eos_token
161
+ device = torch.device(torch.distributed.get_rank())
162
+ model = AutoModelForCausalLM.from_pretrained(pretrained_model_name).to(device)
163
+
164
+ def compute_metrics(p: EvalPrediction) -> dict[str, bool]:
165
+ return {"accuracy": (p.predictions == p.label_ids).mean()}
166
+
167
+ trainer = Seq2SeqTrainer(
168
+ model=model,
169
+ args=training_args,
170
+ data_collator=DataCollatorForSeq2Seq(tokenizer, model),
171
+ eval_dataset=DummyTextDataset(tokenizer),
172
+ compute_metrics=compute_metrics,
173
+ )
174
+
175
+ metrics = trainer.evaluate()
docs/transformers/tests/utils/import_structures/failing_export.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # fmt: off
16
+
17
+ from transformers.utils.import_utils import requires
18
+
19
+
20
+ @requires(backends=("random_item_that_should_not_exist",))
21
+ class A0:
22
+ def __init__(self):
23
+ pass
docs/transformers/tests/utils/import_structures/import_structure_raw_register.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # fmt: off
16
+
17
+ from transformers.utils.import_utils import requires
18
+
19
+
20
+ @requires()
21
+ class A0:
22
+ def __init__(self):
23
+ pass
24
+
25
+
26
+ @requires()
27
+ def a0():
28
+ pass
29
+
30
+
31
+ @requires(backends=("torch", "tf"))
32
+ class A1:
33
+ def __init__(self):
34
+ pass
35
+
36
+
37
+ @requires(backends=("torch", "tf"))
38
+ def a1():
39
+ pass
40
+
41
+
42
+ @requires(
43
+ backends=("torch", "tf")
44
+ )
45
+ class A2:
46
+ def __init__(self):
47
+ pass
48
+
49
+
50
+ @requires(
51
+ backends=("torch", "tf")
52
+ )
53
+ def a2():
54
+ pass
55
+
56
+
57
+ @requires(
58
+ backends=(
59
+ "torch",
60
+ "tf"
61
+ )
62
+ )
63
+ class A3:
64
+ def __init__(self):
65
+ pass
66
+
67
+
68
+ @requires(
69
+ backends=(
70
+ "torch",
71
+ "tf"
72
+ )
73
+ )
74
+ def a3():
75
+ pass
76
+
77
+ @requires(backends=())
78
+ class A4:
79
+ def __init__(self):
80
+ pass
docs/transformers/tests/utils/import_structures/import_structure_register_with_comments.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # fmt: off
16
+
17
+ from transformers.utils.import_utils import requires
18
+
19
+
20
+ @requires()
21
+ # That's a statement
22
+ class B0:
23
+ def __init__(self):
24
+ pass
25
+
26
+
27
+ @requires()
28
+ # That's a statement
29
+ def b0():
30
+ pass
31
+
32
+
33
+ @requires(backends=("torch", "tf"))
34
+ # That's a statement
35
+ class B1:
36
+ def __init__(self):
37
+ pass
38
+
39
+
40
+ @requires(backends=("torch", "tf"))
41
+ # That's a statement
42
+ def b1():
43
+ pass
44
+
45
+
46
+ @requires(backends=("torch", "tf"))
47
+ # That's a statement
48
+ class B2:
49
+ def __init__(self):
50
+ pass
51
+
52
+
53
+ @requires(backends=("torch", "tf"))
54
+ # That's a statement
55
+ def b2():
56
+ pass
57
+
58
+
59
+ @requires(
60
+ backends=(
61
+ "torch",
62
+ "tf"
63
+ )
64
+ )
65
+ # That's a statement
66
+ class B3:
67
+ def __init__(self):
68
+ pass
69
+
70
+
71
+ @requires(
72
+ backends=(
73
+ "torch",
74
+ "tf"
75
+ )
76
+ )
77
+ # That's a statement
78
+ def b3():
79
+ pass
docs/transformers/tests/utils/import_structures/import_structure_register_with_duplicates.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # fmt: off
16
+
17
+ from transformers.utils.import_utils import requires
18
+
19
+
20
+ @requires(backends=("torch", "torch"))
21
+ class C0:
22
+ def __init__(self):
23
+ pass
24
+
25
+
26
+ @requires(backends=("torch", "torch"))
27
+ def c0():
28
+ pass
29
+
30
+
31
+ @requires(backends=("torch", "torch"))
32
+ # That's a statement
33
+ class C1:
34
+ def __init__(self):
35
+ pass
36
+
37
+
38
+ @requires(backends=("torch", "torch"))
39
+ # That's a statement
40
+ def c1():
41
+ pass
42
+
43
+
44
+ @requires(backends=("torch", "torch"))
45
+ # That's a statement
46
+ class C2:
47
+ def __init__(self):
48
+ pass
49
+
50
+
51
+ @requires(backends=("torch", "torch"))
52
+ # That's a statement
53
+ def c2():
54
+ pass
55
+
56
+
57
+ @requires(
58
+ backends=(
59
+ "torch",
60
+ "torch"
61
+ )
62
+ )
63
+ # That's a statement
64
+ class C3:
65
+ def __init__(self):
66
+ pass
67
+
68
+
69
+ @requires(
70
+ backends=(
71
+ "torch",
72
+ "torch"
73
+ )
74
+ )
75
+ # That's a statement
76
+ def c3():
77
+ pass
docs/transformers/tests/utils/test_activations_tf.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import unittest
16
+
17
+ import numpy as np
18
+
19
+ from transformers import is_tf_available
20
+ from transformers.testing_utils import require_tf
21
+
22
+
23
+ if is_tf_available():
24
+ import tensorflow as tf
25
+
26
+ from transformers.activations_tf import get_tf_activation
27
+
28
+
29
+ @require_tf
30
+ class TestTFActivations(unittest.TestCase):
31
+ def test_gelu_10(self):
32
+ x = tf.constant([-100, -1.0, -0.1, 0, 0.1, 1.0, 100.0])
33
+ gelu = get_tf_activation("gelu")
34
+ gelu10 = get_tf_activation("gelu_10")
35
+
36
+ y_gelu = gelu(x)
37
+ y_gelu_10 = gelu10(x)
38
+
39
+ clipped_mask = tf.where(y_gelu_10 < 10.0, 1.0, 0.0)
40
+
41
+ self.assertEqual(tf.math.reduce_max(y_gelu_10).numpy().item(), 10.0)
42
+ self.assertTrue(np.allclose(y_gelu * clipped_mask, y_gelu_10 * clipped_mask))
43
+
44
+ def test_get_activation(self):
45
+ get_tf_activation("gelu")
46
+ get_tf_activation("gelu_10")
47
+ get_tf_activation("gelu_fast")
48
+ get_tf_activation("gelu_new")
49
+ get_tf_activation("glu")
50
+ get_tf_activation("mish")
51
+ get_tf_activation("quick_gelu")
52
+ get_tf_activation("relu")
53
+ get_tf_activation("sigmoid")
54
+ get_tf_activation("silu")
55
+ get_tf_activation("swish")
56
+ get_tf_activation("tanh")
57
+ with self.assertRaises(KeyError):
58
+ get_tf_activation("bogus")
59
+ with self.assertRaises(KeyError):
60
+ get_tf_activation(None)
docs/transformers/tests/utils/test_add_new_model_like.py ADDED
@@ -0,0 +1,1506 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+ import re
16
+ import tempfile
17
+ import unittest
18
+ from pathlib import Path
19
+
20
+ import transformers
21
+ from transformers.commands.add_new_model_like import (
22
+ ModelPatterns,
23
+ _re_class_func,
24
+ add_content_to_file,
25
+ add_content_to_text,
26
+ clean_frameworks_in_init,
27
+ duplicate_doc_file,
28
+ duplicate_module,
29
+ filter_framework_files,
30
+ find_base_model_checkpoint,
31
+ get_model_files,
32
+ get_module_from_file,
33
+ parse_module_content,
34
+ replace_model_patterns,
35
+ retrieve_info_for_model,
36
+ retrieve_model_classes,
37
+ simplify_replacements,
38
+ )
39
+ from transformers.testing_utils import require_flax, require_tf, require_torch
40
+
41
+
42
+ BERT_MODEL_FILES = {
43
+ "src/transformers/models/bert/__init__.py",
44
+ "src/transformers/models/bert/configuration_bert.py",
45
+ "src/transformers/models/bert/tokenization_bert.py",
46
+ "src/transformers/models/bert/tokenization_bert_fast.py",
47
+ "src/transformers/models/bert/tokenization_bert_tf.py",
48
+ "src/transformers/models/bert/modeling_bert.py",
49
+ "src/transformers/models/bert/modeling_flax_bert.py",
50
+ "src/transformers/models/bert/modeling_tf_bert.py",
51
+ "src/transformers/models/bert/convert_bert_original_tf_checkpoint_to_pytorch.py",
52
+ "src/transformers/models/bert/convert_bert_original_tf2_checkpoint_to_pytorch.py",
53
+ "src/transformers/models/bert/convert_bert_pytorch_checkpoint_to_original_tf.py",
54
+ "src/transformers/models/bert/convert_bert_token_dropping_original_tf2_checkpoint_to_pytorch.py",
55
+ }
56
+
57
+ VIT_MODEL_FILES = {
58
+ "src/transformers/models/vit/__init__.py",
59
+ "src/transformers/models/vit/configuration_vit.py",
60
+ "src/transformers/models/vit/convert_dino_to_pytorch.py",
61
+ "src/transformers/models/vit/convert_vit_timm_to_pytorch.py",
62
+ "src/transformers/models/vit/feature_extraction_vit.py",
63
+ "src/transformers/models/vit/image_processing_vit.py",
64
+ "src/transformers/models/vit/image_processing_vit_fast.py",
65
+ "src/transformers/models/vit/modeling_vit.py",
66
+ "src/transformers/models/vit/modeling_tf_vit.py",
67
+ "src/transformers/models/vit/modeling_flax_vit.py",
68
+ }
69
+
70
+ WAV2VEC2_MODEL_FILES = {
71
+ "src/transformers/models/wav2vec2/__init__.py",
72
+ "src/transformers/models/wav2vec2/configuration_wav2vec2.py",
73
+ "src/transformers/models/wav2vec2/convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py",
74
+ "src/transformers/models/wav2vec2/convert_wav2vec2_original_s3prl_checkpoint_to_pytorch.py",
75
+ "src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py",
76
+ "src/transformers/models/wav2vec2/modeling_wav2vec2.py",
77
+ "src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py",
78
+ "src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py",
79
+ "src/transformers/models/wav2vec2/processing_wav2vec2.py",
80
+ "src/transformers/models/wav2vec2/tokenization_wav2vec2.py",
81
+ }
82
+
83
+ REPO_PATH = Path(transformers.__path__[0]).parent.parent
84
+
85
+
86
+ @require_torch
87
+ @require_tf
88
+ @require_flax
89
+ class TestAddNewModelLike(unittest.TestCase):
90
+ def init_file(self, file_name, content):
91
+ with open(file_name, "w", encoding="utf-8") as f:
92
+ f.write(content)
93
+
94
+ def check_result(self, file_name, expected_result):
95
+ with open(file_name, encoding="utf-8") as f:
96
+ result = f.read()
97
+ self.assertEqual(result, expected_result)
98
+
99
+ def test_re_class_func(self):
100
+ self.assertEqual(_re_class_func.search("def my_function(x, y):").groups()[0], "my_function")
101
+ self.assertEqual(_re_class_func.search("class MyClass:").groups()[0], "MyClass")
102
+ self.assertEqual(_re_class_func.search("class MyClass(SuperClass):").groups()[0], "MyClass")
103
+
104
+ def test_model_patterns_defaults(self):
105
+ model_patterns = ModelPatterns("GPT-New new", "huggingface/gpt-new-base")
106
+
107
+ self.assertEqual(model_patterns.model_type, "gpt-new-new")
108
+ self.assertEqual(model_patterns.model_lower_cased, "gpt_new_new")
109
+ self.assertEqual(model_patterns.model_camel_cased, "GPTNewNew")
110
+ self.assertEqual(model_patterns.model_upper_cased, "GPT_NEW_NEW")
111
+ self.assertEqual(model_patterns.config_class, "GPTNewNewConfig")
112
+ self.assertIsNone(model_patterns.tokenizer_class)
113
+ self.assertIsNone(model_patterns.feature_extractor_class)
114
+ self.assertIsNone(model_patterns.processor_class)
115
+
116
+ def test_parse_module_content(self):
117
+ test_code = """SOME_CONSTANT = a constant
118
+
119
+ CONSTANT_DEFINED_ON_SEVERAL_LINES = [
120
+ first_item,
121
+ second_item
122
+ ]
123
+
124
+ def function(args):
125
+ some code
126
+
127
+ # Copied from transformers.some_module
128
+ class SomeClass:
129
+ some code
130
+ """
131
+
132
+ expected_parts = [
133
+ "SOME_CONSTANT = a constant\n",
134
+ "CONSTANT_DEFINED_ON_SEVERAL_LINES = [\n first_item,\n second_item\n]",
135
+ "",
136
+ "def function(args):\n some code\n",
137
+ "# Copied from transformers.some_module\nclass SomeClass:\n some code\n",
138
+ ]
139
+ self.assertEqual(parse_module_content(test_code), expected_parts)
140
+
141
+ def test_add_content_to_text(self):
142
+ test_text = """all_configs = {
143
+ "gpt": "GPTConfig",
144
+ "bert": "BertConfig",
145
+ "t5": "T5Config",
146
+ }"""
147
+
148
+ expected = """all_configs = {
149
+ "gpt": "GPTConfig",
150
+ "gpt2": "GPT2Config",
151
+ "bert": "BertConfig",
152
+ "t5": "T5Config",
153
+ }"""
154
+ line = ' "gpt2": "GPT2Config",'
155
+
156
+ self.assertEqual(add_content_to_text(test_text, line, add_before="bert"), expected)
157
+ self.assertEqual(add_content_to_text(test_text, line, add_before="bert", exact_match=True), test_text)
158
+ self.assertEqual(
159
+ add_content_to_text(test_text, line, add_before=' "bert": "BertConfig",', exact_match=True), expected
160
+ )
161
+ self.assertEqual(add_content_to_text(test_text, line, add_before=re.compile(r'^\s*"bert":')), expected)
162
+
163
+ self.assertEqual(add_content_to_text(test_text, line, add_after="gpt"), expected)
164
+ self.assertEqual(add_content_to_text(test_text, line, add_after="gpt", exact_match=True), test_text)
165
+ self.assertEqual(
166
+ add_content_to_text(test_text, line, add_after=' "gpt": "GPTConfig",', exact_match=True), expected
167
+ )
168
+ self.assertEqual(add_content_to_text(test_text, line, add_after=re.compile(r'^\s*"gpt":')), expected)
169
+
170
+ def test_add_content_to_file(self):
171
+ test_text = """all_configs = {
172
+ "gpt": "GPTConfig",
173
+ "bert": "BertConfig",
174
+ "t5": "T5Config",
175
+ }"""
176
+
177
+ expected = """all_configs = {
178
+ "gpt": "GPTConfig",
179
+ "gpt2": "GPT2Config",
180
+ "bert": "BertConfig",
181
+ "t5": "T5Config",
182
+ }"""
183
+ line = ' "gpt2": "GPT2Config",'
184
+
185
+ with tempfile.TemporaryDirectory() as tmp_dir:
186
+ file_name = os.path.join(tmp_dir, "code.py")
187
+
188
+ self.init_file(file_name, test_text)
189
+ add_content_to_file(file_name, line, add_before="bert")
190
+ self.check_result(file_name, expected)
191
+
192
+ self.init_file(file_name, test_text)
193
+ add_content_to_file(file_name, line, add_before="bert", exact_match=True)
194
+ self.check_result(file_name, test_text)
195
+
196
+ self.init_file(file_name, test_text)
197
+ add_content_to_file(file_name, line, add_before=' "bert": "BertConfig",', exact_match=True)
198
+ self.check_result(file_name, expected)
199
+
200
+ self.init_file(file_name, test_text)
201
+ add_content_to_file(file_name, line, add_before=re.compile(r'^\s*"bert":'))
202
+ self.check_result(file_name, expected)
203
+
204
+ self.init_file(file_name, test_text)
205
+ add_content_to_file(file_name, line, add_after="gpt")
206
+ self.check_result(file_name, expected)
207
+
208
+ self.init_file(file_name, test_text)
209
+ add_content_to_file(file_name, line, add_after="gpt", exact_match=True)
210
+ self.check_result(file_name, test_text)
211
+
212
+ self.init_file(file_name, test_text)
213
+ add_content_to_file(file_name, line, add_after=' "gpt": "GPTConfig",', exact_match=True)
214
+ self.check_result(file_name, expected)
215
+
216
+ self.init_file(file_name, test_text)
217
+ add_content_to_file(file_name, line, add_after=re.compile(r'^\s*"gpt":'))
218
+ self.check_result(file_name, expected)
219
+
220
+ def test_simplify_replacements(self):
221
+ self.assertEqual(simplify_replacements([("Bert", "NewBert")]), [("Bert", "NewBert")])
222
+ self.assertEqual(
223
+ simplify_replacements([("Bert", "NewBert"), ("bert", "new-bert")]),
224
+ [("Bert", "NewBert"), ("bert", "new-bert")],
225
+ )
226
+ self.assertEqual(
227
+ simplify_replacements([("BertConfig", "NewBertConfig"), ("Bert", "NewBert"), ("bert", "new-bert")]),
228
+ [("Bert", "NewBert"), ("bert", "new-bert")],
229
+ )
230
+
231
+ def test_replace_model_patterns(self):
232
+ bert_model_patterns = ModelPatterns("Bert", "google-bert/bert-base-cased")
233
+ new_bert_model_patterns = ModelPatterns("New Bert", "huggingface/bert-new-base")
234
+ bert_test = '''class TFBertPreTrainedModel(PreTrainedModel):
235
+ """
236
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
237
+ models.
238
+ """
239
+
240
+ config_class = BertConfig
241
+ load_tf_weights = load_tf_weights_in_bert
242
+ base_model_prefix = "bert"
243
+ is_parallelizable = True
244
+ supports_gradient_checkpointing = True
245
+ model_type = "bert"
246
+
247
+ BERT_CONSTANT = "value"
248
+ '''
249
+ bert_expected = '''class TFNewBertPreTrainedModel(PreTrainedModel):
250
+ """
251
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
252
+ models.
253
+ """
254
+
255
+ config_class = NewBertConfig
256
+ load_tf_weights = load_tf_weights_in_new_bert
257
+ base_model_prefix = "new_bert"
258
+ is_parallelizable = True
259
+ supports_gradient_checkpointing = True
260
+ model_type = "new-bert"
261
+
262
+ NEW_BERT_CONSTANT = "value"
263
+ '''
264
+
265
+ bert_converted, replacements = replace_model_patterns(bert_test, bert_model_patterns, new_bert_model_patterns)
266
+ self.assertEqual(bert_converted, bert_expected)
267
+ # Replacements are empty here since bert as been replaced by bert_new in some instances and bert-new
268
+ # in others.
269
+ self.assertEqual(replacements, "")
270
+
271
+ # If we remove the model type, we will get replacements
272
+ bert_test = bert_test.replace(' model_type = "bert"\n', "")
273
+ bert_expected = bert_expected.replace(' model_type = "new-bert"\n', "")
274
+ bert_converted, replacements = replace_model_patterns(bert_test, bert_model_patterns, new_bert_model_patterns)
275
+ self.assertEqual(bert_converted, bert_expected)
276
+ self.assertEqual(replacements, "BERT->NEW_BERT,Bert->NewBert,bert->new_bert")
277
+
278
+ gpt_model_patterns = ModelPatterns("GPT2", "gpt2")
279
+ new_gpt_model_patterns = ModelPatterns("GPT-New new", "huggingface/gpt-new-base")
280
+ gpt_test = '''class GPT2PreTrainedModel(PreTrainedModel):
281
+ """
282
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
283
+ models.
284
+ """
285
+
286
+ config_class = GPT2Config
287
+ load_tf_weights = load_tf_weights_in_gpt2
288
+ base_model_prefix = "transformer"
289
+ is_parallelizable = True
290
+ supports_gradient_checkpointing = True
291
+
292
+ GPT2_CONSTANT = "value"
293
+ '''
294
+
295
+ gpt_expected = '''class GPTNewNewPreTrainedModel(PreTrainedModel):
296
+ """
297
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
298
+ models.
299
+ """
300
+
301
+ config_class = GPTNewNewConfig
302
+ load_tf_weights = load_tf_weights_in_gpt_new_new
303
+ base_model_prefix = "transformer"
304
+ is_parallelizable = True
305
+ supports_gradient_checkpointing = True
306
+
307
+ GPT_NEW_NEW_CONSTANT = "value"
308
+ '''
309
+
310
+ gpt_converted, replacements = replace_model_patterns(gpt_test, gpt_model_patterns, new_gpt_model_patterns)
311
+ self.assertEqual(gpt_converted, gpt_expected)
312
+ # Replacements are empty here since GPT2 as been replaced by GPTNewNew in some instances and GPT_NEW_NEW
313
+ # in others.
314
+ self.assertEqual(replacements, "")
315
+
316
+ roberta_model_patterns = ModelPatterns("RoBERTa", "FacebookAI/roberta-base", model_camel_cased="Roberta")
317
+ new_roberta_model_patterns = ModelPatterns(
318
+ "RoBERTa-New", "huggingface/roberta-new-base", model_camel_cased="RobertaNew"
319
+ )
320
+ roberta_test = '''# Copied from transformers.models.bert.BertModel with Bert->Roberta
321
+ class RobertaModel(RobertaPreTrainedModel):
322
+ """ The base RoBERTa model. """
323
+ checkpoint = FacebookAI/roberta-base
324
+ base_model_prefix = "roberta"
325
+ '''
326
+ roberta_expected = '''# Copied from transformers.models.bert.BertModel with Bert->RobertaNew
327
+ class RobertaNewModel(RobertaNewPreTrainedModel):
328
+ """ The base RoBERTa-New model. """
329
+ checkpoint = huggingface/roberta-new-base
330
+ base_model_prefix = "roberta_new"
331
+ '''
332
+ roberta_converted, replacements = replace_model_patterns(
333
+ roberta_test, roberta_model_patterns, new_roberta_model_patterns
334
+ )
335
+ self.assertEqual(roberta_converted, roberta_expected)
336
+
337
+ def test_get_module_from_file(self):
338
+ self.assertEqual(
339
+ get_module_from_file("/git/transformers/src/transformers/models/bert/modeling_tf_bert.py"),
340
+ "transformers.models.bert.modeling_tf_bert",
341
+ )
342
+ self.assertEqual(
343
+ get_module_from_file("/transformers/models/gpt2/modeling_gpt2.py"),
344
+ "transformers.models.gpt2.modeling_gpt2",
345
+ )
346
+ with self.assertRaises(ValueError):
347
+ get_module_from_file("/models/gpt2/modeling_gpt2.py")
348
+
349
+ def test_duplicate_module(self):
350
+ bert_model_patterns = ModelPatterns("Bert", "google-bert/bert-base-cased")
351
+ new_bert_model_patterns = ModelPatterns("New Bert", "huggingface/bert-new-base")
352
+ bert_test = '''class TFBertPreTrainedModel(PreTrainedModel):
353
+ """
354
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
355
+ models.
356
+ """
357
+
358
+ config_class = BertConfig
359
+ load_tf_weights = load_tf_weights_in_bert
360
+ base_model_prefix = "bert"
361
+ is_parallelizable = True
362
+ supports_gradient_checkpointing = True
363
+
364
+ BERT_CONSTANT = "value"
365
+ '''
366
+ bert_expected = '''class TFNewBertPreTrainedModel(PreTrainedModel):
367
+ """
368
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
369
+ models.
370
+ """
371
+
372
+ config_class = NewBertConfig
373
+ load_tf_weights = load_tf_weights_in_new_bert
374
+ base_model_prefix = "new_bert"
375
+ is_parallelizable = True
376
+ supports_gradient_checkpointing = True
377
+
378
+ NEW_BERT_CONSTANT = "value"
379
+ '''
380
+ bert_expected_with_copied_from = (
381
+ "# Copied from transformers.bert_module.TFBertPreTrainedModel with Bert->NewBert,bert->new_bert\n"
382
+ + bert_expected
383
+ )
384
+ with tempfile.TemporaryDirectory() as tmp_dir:
385
+ work_dir = os.path.join(tmp_dir, "transformers")
386
+ os.makedirs(work_dir)
387
+ file_name = os.path.join(work_dir, "bert_module.py")
388
+ dest_file_name = os.path.join(work_dir, "new_bert_module.py")
389
+
390
+ self.init_file(file_name, bert_test)
391
+ duplicate_module(file_name, bert_model_patterns, new_bert_model_patterns)
392
+ self.check_result(dest_file_name, bert_expected_with_copied_from)
393
+
394
+ self.init_file(file_name, bert_test)
395
+ duplicate_module(file_name, bert_model_patterns, new_bert_model_patterns, add_copied_from=False)
396
+ self.check_result(dest_file_name, bert_expected)
397
+
398
+ def test_duplicate_module_with_copied_from(self):
399
+ bert_model_patterns = ModelPatterns("Bert", "google-bert/bert-base-cased")
400
+ new_bert_model_patterns = ModelPatterns("New Bert", "huggingface/bert-new-base")
401
+ bert_test = '''# Copied from transformers.models.xxx.XxxModel with Xxx->Bert
402
+ class TFBertPreTrainedModel(PreTrainedModel):
403
+ """
404
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
405
+ models.
406
+ """
407
+
408
+ config_class = BertConfig
409
+ load_tf_weights = load_tf_weights_in_bert
410
+ base_model_prefix = "bert"
411
+ is_parallelizable = True
412
+ supports_gradient_checkpointing = True
413
+
414
+ BERT_CONSTANT = "value"
415
+ '''
416
+ bert_expected = '''# Copied from transformers.models.xxx.XxxModel with Xxx->NewBert
417
+ class TFNewBertPreTrainedModel(PreTrainedModel):
418
+ """
419
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
420
+ models.
421
+ """
422
+
423
+ config_class = NewBertConfig
424
+ load_tf_weights = load_tf_weights_in_new_bert
425
+ base_model_prefix = "new_bert"
426
+ is_parallelizable = True
427
+ supports_gradient_checkpointing = True
428
+
429
+ NEW_BERT_CONSTANT = "value"
430
+ '''
431
+ with tempfile.TemporaryDirectory() as tmp_dir:
432
+ work_dir = os.path.join(tmp_dir, "transformers")
433
+ os.makedirs(work_dir)
434
+ file_name = os.path.join(work_dir, "bert_module.py")
435
+ dest_file_name = os.path.join(work_dir, "new_bert_module.py")
436
+
437
+ self.init_file(file_name, bert_test)
438
+ duplicate_module(file_name, bert_model_patterns, new_bert_model_patterns)
439
+ # There should not be a new Copied from statement, the old one should be adapted.
440
+ self.check_result(dest_file_name, bert_expected)
441
+
442
+ self.init_file(file_name, bert_test)
443
+ duplicate_module(file_name, bert_model_patterns, new_bert_model_patterns, add_copied_from=False)
444
+ self.check_result(dest_file_name, bert_expected)
445
+
446
+ def test_filter_framework_files(self):
447
+ files = ["modeling_bert.py", "modeling_tf_bert.py", "modeling_flax_bert.py", "configuration_bert.py"]
448
+ self.assertEqual(filter_framework_files(files), files)
449
+ self.assertEqual(set(filter_framework_files(files, ["pt", "tf", "flax"])), set(files))
450
+
451
+ self.assertEqual(set(filter_framework_files(files, ["pt"])), {"modeling_bert.py", "configuration_bert.py"})
452
+ self.assertEqual(set(filter_framework_files(files, ["tf"])), {"modeling_tf_bert.py", "configuration_bert.py"})
453
+ self.assertEqual(
454
+ set(filter_framework_files(files, ["flax"])), {"modeling_flax_bert.py", "configuration_bert.py"}
455
+ )
456
+
457
+ self.assertEqual(
458
+ set(filter_framework_files(files, ["pt", "tf"])),
459
+ {"modeling_tf_bert.py", "modeling_bert.py", "configuration_bert.py"},
460
+ )
461
+ self.assertEqual(
462
+ set(filter_framework_files(files, ["tf", "flax"])),
463
+ {"modeling_tf_bert.py", "modeling_flax_bert.py", "configuration_bert.py"},
464
+ )
465
+ self.assertEqual(
466
+ set(filter_framework_files(files, ["pt", "flax"])),
467
+ {"modeling_bert.py", "modeling_flax_bert.py", "configuration_bert.py"},
468
+ )
469
+
470
+ def test_get_model_files(self):
471
+ # BERT
472
+ bert_files = get_model_files("bert")
473
+
474
+ doc_file = str(Path(bert_files["doc_file"]).relative_to(REPO_PATH))
475
+ self.assertEqual(doc_file, "docs/source/en/model_doc/bert.md")
476
+
477
+ model_files = {str(Path(f).relative_to(REPO_PATH)) for f in bert_files["model_files"]}
478
+ self.assertEqual(model_files, BERT_MODEL_FILES)
479
+
480
+ self.assertEqual(bert_files["module_name"], "bert")
481
+
482
+ test_files = {str(Path(f).relative_to(REPO_PATH)) for f in bert_files["test_files"]}
483
+ bert_test_files = {
484
+ "tests/models/bert/test_tokenization_bert.py",
485
+ "tests/models/bert/test_modeling_bert.py",
486
+ "tests/models/bert/test_modeling_tf_bert.py",
487
+ "tests/models/bert/test_modeling_flax_bert.py",
488
+ }
489
+ self.assertEqual(test_files, bert_test_files)
490
+
491
+ # VIT
492
+ vit_files = get_model_files("vit")
493
+ doc_file = str(Path(vit_files["doc_file"]).relative_to(REPO_PATH))
494
+ self.assertEqual(doc_file, "docs/source/en/model_doc/vit.md")
495
+
496
+ model_files = {str(Path(f).relative_to(REPO_PATH)) for f in vit_files["model_files"]}
497
+ self.assertEqual(model_files, VIT_MODEL_FILES)
498
+
499
+ self.assertEqual(vit_files["module_name"], "vit")
500
+
501
+ test_files = {str(Path(f).relative_to(REPO_PATH)) for f in vit_files["test_files"]}
502
+ vit_test_files = {
503
+ "tests/models/vit/test_image_processing_vit.py",
504
+ "tests/models/vit/test_modeling_vit.py",
505
+ "tests/models/vit/test_modeling_tf_vit.py",
506
+ "tests/models/vit/test_modeling_flax_vit.py",
507
+ }
508
+ self.assertEqual(test_files, vit_test_files)
509
+
510
+ # Wav2Vec2
511
+ wav2vec2_files = get_model_files("wav2vec2")
512
+ doc_file = str(Path(wav2vec2_files["doc_file"]).relative_to(REPO_PATH))
513
+ self.assertEqual(doc_file, "docs/source/en/model_doc/wav2vec2.md")
514
+
515
+ model_files = {str(Path(f).relative_to(REPO_PATH)) for f in wav2vec2_files["model_files"]}
516
+ self.assertEqual(model_files, WAV2VEC2_MODEL_FILES)
517
+
518
+ self.assertEqual(wav2vec2_files["module_name"], "wav2vec2")
519
+
520
+ test_files = {str(Path(f).relative_to(REPO_PATH)) for f in wav2vec2_files["test_files"]}
521
+ wav2vec2_test_files = {
522
+ "tests/models/wav2vec2/test_feature_extraction_wav2vec2.py",
523
+ "tests/models/wav2vec2/test_modeling_wav2vec2.py",
524
+ "tests/models/wav2vec2/test_modeling_tf_wav2vec2.py",
525
+ "tests/models/wav2vec2/test_modeling_flax_wav2vec2.py",
526
+ "tests/models/wav2vec2/test_processor_wav2vec2.py",
527
+ "tests/models/wav2vec2/test_tokenization_wav2vec2.py",
528
+ }
529
+ self.assertEqual(test_files, wav2vec2_test_files)
530
+
531
+ def test_get_model_files_only_pt(self):
532
+ # BERT
533
+ bert_files = get_model_files("bert", frameworks=["pt"])
534
+
535
+ doc_file = str(Path(bert_files["doc_file"]).relative_to(REPO_PATH))
536
+ self.assertEqual(doc_file, "docs/source/en/model_doc/bert.md")
537
+
538
+ model_files = {str(Path(f).relative_to(REPO_PATH)) for f in bert_files["model_files"]}
539
+ bert_model_files = BERT_MODEL_FILES - {
540
+ "src/transformers/models/bert/modeling_tf_bert.py",
541
+ "src/transformers/models/bert/modeling_flax_bert.py",
542
+ }
543
+ self.assertEqual(model_files, bert_model_files)
544
+
545
+ self.assertEqual(bert_files["module_name"], "bert")
546
+
547
+ test_files = {str(Path(f).relative_to(REPO_PATH)) for f in bert_files["test_files"]}
548
+ bert_test_files = {
549
+ "tests/models/bert/test_tokenization_bert.py",
550
+ "tests/models/bert/test_modeling_bert.py",
551
+ }
552
+ self.assertEqual(test_files, bert_test_files)
553
+
554
+ # VIT
555
+ vit_files = get_model_files("vit", frameworks=["pt"])
556
+ doc_file = str(Path(vit_files["doc_file"]).relative_to(REPO_PATH))
557
+ self.assertEqual(doc_file, "docs/source/en/model_doc/vit.md")
558
+
559
+ model_files = {str(Path(f).relative_to(REPO_PATH)) for f in vit_files["model_files"]}
560
+ vit_model_files = VIT_MODEL_FILES - {
561
+ "src/transformers/models/vit/modeling_tf_vit.py",
562
+ "src/transformers/models/vit/modeling_flax_vit.py",
563
+ }
564
+ self.assertEqual(model_files, vit_model_files)
565
+
566
+ self.assertEqual(vit_files["module_name"], "vit")
567
+
568
+ test_files = {str(Path(f).relative_to(REPO_PATH)) for f in vit_files["test_files"]}
569
+ vit_test_files = {
570
+ "tests/models/vit/test_image_processing_vit.py",
571
+ "tests/models/vit/test_modeling_vit.py",
572
+ }
573
+ self.assertEqual(test_files, vit_test_files)
574
+
575
+ # Wav2Vec2
576
+ wav2vec2_files = get_model_files("wav2vec2", frameworks=["pt"])
577
+ doc_file = str(Path(wav2vec2_files["doc_file"]).relative_to(REPO_PATH))
578
+ self.assertEqual(doc_file, "docs/source/en/model_doc/wav2vec2.md")
579
+
580
+ model_files = {str(Path(f).relative_to(REPO_PATH)) for f in wav2vec2_files["model_files"]}
581
+ wav2vec2_model_files = WAV2VEC2_MODEL_FILES - {
582
+ "src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py",
583
+ "src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py",
584
+ }
585
+ self.assertEqual(model_files, wav2vec2_model_files)
586
+
587
+ self.assertEqual(wav2vec2_files["module_name"], "wav2vec2")
588
+
589
+ test_files = {str(Path(f).relative_to(REPO_PATH)) for f in wav2vec2_files["test_files"]}
590
+ wav2vec2_test_files = {
591
+ "tests/models/wav2vec2/test_feature_extraction_wav2vec2.py",
592
+ "tests/models/wav2vec2/test_modeling_wav2vec2.py",
593
+ "tests/models/wav2vec2/test_processor_wav2vec2.py",
594
+ "tests/models/wav2vec2/test_tokenization_wav2vec2.py",
595
+ }
596
+ self.assertEqual(test_files, wav2vec2_test_files)
597
+
598
+ def test_get_model_files_tf_and_flax(self):
599
+ # BERT
600
+ bert_files = get_model_files("bert", frameworks=["tf", "flax"])
601
+
602
+ doc_file = str(Path(bert_files["doc_file"]).relative_to(REPO_PATH))
603
+ self.assertEqual(doc_file, "docs/source/en/model_doc/bert.md")
604
+
605
+ model_files = {str(Path(f).relative_to(REPO_PATH)) for f in bert_files["model_files"]}
606
+ bert_model_files = BERT_MODEL_FILES - {"src/transformers/models/bert/modeling_bert.py"}
607
+ self.assertEqual(model_files, bert_model_files)
608
+
609
+ self.assertEqual(bert_files["module_name"], "bert")
610
+
611
+ test_files = {str(Path(f).relative_to(REPO_PATH)) for f in bert_files["test_files"]}
612
+ bert_test_files = {
613
+ "tests/models/bert/test_tokenization_bert.py",
614
+ "tests/models/bert/test_modeling_tf_bert.py",
615
+ "tests/models/bert/test_modeling_flax_bert.py",
616
+ }
617
+ self.assertEqual(test_files, bert_test_files)
618
+
619
+ # VIT
620
+ vit_files = get_model_files("vit", frameworks=["tf", "flax"])
621
+ doc_file = str(Path(vit_files["doc_file"]).relative_to(REPO_PATH))
622
+ self.assertEqual(doc_file, "docs/source/en/model_doc/vit.md")
623
+
624
+ model_files = {str(Path(f).relative_to(REPO_PATH)) for f in vit_files["model_files"]}
625
+ vit_model_files = VIT_MODEL_FILES - {"src/transformers/models/vit/modeling_vit.py"}
626
+ self.assertEqual(model_files, vit_model_files)
627
+
628
+ self.assertEqual(vit_files["module_name"], "vit")
629
+
630
+ test_files = {str(Path(f).relative_to(REPO_PATH)) for f in vit_files["test_files"]}
631
+ vit_test_files = {
632
+ "tests/models/vit/test_image_processing_vit.py",
633
+ "tests/models/vit/test_modeling_tf_vit.py",
634
+ "tests/models/vit/test_modeling_flax_vit.py",
635
+ }
636
+ self.assertEqual(test_files, vit_test_files)
637
+
638
+ # Wav2Vec2
639
+ wav2vec2_files = get_model_files("wav2vec2", frameworks=["tf", "flax"])
640
+ doc_file = str(Path(wav2vec2_files["doc_file"]).relative_to(REPO_PATH))
641
+ self.assertEqual(doc_file, "docs/source/en/model_doc/wav2vec2.md")
642
+
643
+ model_files = {str(Path(f).relative_to(REPO_PATH)) for f in wav2vec2_files["model_files"]}
644
+ wav2vec2_model_files = WAV2VEC2_MODEL_FILES - {"src/transformers/models/wav2vec2/modeling_wav2vec2.py"}
645
+ self.assertEqual(model_files, wav2vec2_model_files)
646
+
647
+ self.assertEqual(wav2vec2_files["module_name"], "wav2vec2")
648
+
649
+ test_files = {str(Path(f).relative_to(REPO_PATH)) for f in wav2vec2_files["test_files"]}
650
+ wav2vec2_test_files = {
651
+ "tests/models/wav2vec2/test_feature_extraction_wav2vec2.py",
652
+ "tests/models/wav2vec2/test_modeling_tf_wav2vec2.py",
653
+ "tests/models/wav2vec2/test_modeling_flax_wav2vec2.py",
654
+ "tests/models/wav2vec2/test_processor_wav2vec2.py",
655
+ "tests/models/wav2vec2/test_tokenization_wav2vec2.py",
656
+ }
657
+ self.assertEqual(test_files, wav2vec2_test_files)
658
+
659
+ def test_find_base_model_checkpoint(self):
660
+ self.assertEqual(find_base_model_checkpoint("bert"), "google-bert/bert-base-uncased")
661
+ self.assertEqual(find_base_model_checkpoint("gpt2"), "openai-community/gpt2")
662
+
663
+ def test_retrieve_model_classes(self):
664
+ gpt_classes = {k: set(v) for k, v in retrieve_model_classes("gpt2").items()}
665
+ expected_gpt_classes = {
666
+ "pt": {
667
+ "GPT2ForTokenClassification",
668
+ "GPT2Model",
669
+ "GPT2LMHeadModel",
670
+ "GPT2ForSequenceClassification",
671
+ "GPT2ForQuestionAnswering",
672
+ },
673
+ "tf": {"TFGPT2Model", "TFGPT2ForSequenceClassification", "TFGPT2LMHeadModel"},
674
+ "flax": {"FlaxGPT2Model", "FlaxGPT2LMHeadModel"},
675
+ }
676
+ self.assertEqual(gpt_classes, expected_gpt_classes)
677
+
678
+ del expected_gpt_classes["flax"]
679
+ gpt_classes = {k: set(v) for k, v in retrieve_model_classes("gpt2", frameworks=["pt", "tf"]).items()}
680
+ self.assertEqual(gpt_classes, expected_gpt_classes)
681
+
682
+ del expected_gpt_classes["pt"]
683
+ gpt_classes = {k: set(v) for k, v in retrieve_model_classes("gpt2", frameworks=["tf"]).items()}
684
+ self.assertEqual(gpt_classes, expected_gpt_classes)
685
+
686
+ def test_retrieve_info_for_model_with_bert(self):
687
+ bert_info = retrieve_info_for_model("bert")
688
+ bert_classes = [
689
+ "BertForTokenClassification",
690
+ "BertForQuestionAnswering",
691
+ "BertForNextSentencePrediction",
692
+ "BertForSequenceClassification",
693
+ "BertForMaskedLM",
694
+ "BertForMultipleChoice",
695
+ "BertModel",
696
+ "BertForPreTraining",
697
+ "BertLMHeadModel",
698
+ ]
699
+ expected_model_classes = {
700
+ "pt": set(bert_classes),
701
+ "tf": {f"TF{m}" for m in bert_classes},
702
+ "flax": {f"Flax{m}" for m in bert_classes[:-1] + ["BertForCausalLM"]},
703
+ }
704
+
705
+ self.assertEqual(set(bert_info["frameworks"]), {"pt", "tf", "flax"})
706
+ model_classes = {k: set(v) for k, v in bert_info["model_classes"].items()}
707
+ self.assertEqual(model_classes, expected_model_classes)
708
+
709
+ all_bert_files = bert_info["model_files"]
710
+ model_files = {str(Path(f).relative_to(REPO_PATH)) for f in all_bert_files["model_files"]}
711
+ self.assertEqual(model_files, BERT_MODEL_FILES)
712
+
713
+ test_files = {str(Path(f).relative_to(REPO_PATH)) for f in all_bert_files["test_files"]}
714
+ bert_test_files = {
715
+ "tests/models/bert/test_tokenization_bert.py",
716
+ "tests/models/bert/test_modeling_bert.py",
717
+ "tests/models/bert/test_modeling_tf_bert.py",
718
+ "tests/models/bert/test_modeling_flax_bert.py",
719
+ }
720
+ self.assertEqual(test_files, bert_test_files)
721
+
722
+ doc_file = str(Path(all_bert_files["doc_file"]).relative_to(REPO_PATH))
723
+ self.assertEqual(doc_file, "docs/source/en/model_doc/bert.md")
724
+
725
+ self.assertEqual(all_bert_files["module_name"], "bert")
726
+
727
+ bert_model_patterns = bert_info["model_patterns"]
728
+ self.assertEqual(bert_model_patterns.model_name, "BERT")
729
+ self.assertEqual(bert_model_patterns.checkpoint, "google-bert/bert-base-uncased")
730
+ self.assertEqual(bert_model_patterns.model_type, "bert")
731
+ self.assertEqual(bert_model_patterns.model_lower_cased, "bert")
732
+ self.assertEqual(bert_model_patterns.model_camel_cased, "Bert")
733
+ self.assertEqual(bert_model_patterns.model_upper_cased, "BERT")
734
+ self.assertEqual(bert_model_patterns.config_class, "BertConfig")
735
+ self.assertEqual(bert_model_patterns.tokenizer_class, "BertTokenizer")
736
+ self.assertIsNone(bert_model_patterns.feature_extractor_class)
737
+ self.assertIsNone(bert_model_patterns.processor_class)
738
+
739
+ def test_retrieve_info_for_model_with_vit(self):
740
+ vit_info = retrieve_info_for_model("vit")
741
+ vit_classes = ["ViTForImageClassification", "ViTModel"]
742
+ pt_only_classes = ["ViTForMaskedImageModeling"]
743
+ expected_model_classes = {
744
+ "pt": set(vit_classes + pt_only_classes),
745
+ "tf": {f"TF{m}" for m in vit_classes},
746
+ "flax": {f"Flax{m}" for m in vit_classes},
747
+ }
748
+
749
+ self.assertEqual(set(vit_info["frameworks"]), {"pt", "tf", "flax"})
750
+ model_classes = {k: set(v) for k, v in vit_info["model_classes"].items()}
751
+ self.assertEqual(model_classes, expected_model_classes)
752
+
753
+ all_vit_files = vit_info["model_files"]
754
+ model_files = {str(Path(f).relative_to(REPO_PATH)) for f in all_vit_files["model_files"]}
755
+ self.assertEqual(model_files, VIT_MODEL_FILES)
756
+
757
+ test_files = {str(Path(f).relative_to(REPO_PATH)) for f in all_vit_files["test_files"]}
758
+ vit_test_files = {
759
+ "tests/models/vit/test_image_processing_vit.py",
760
+ "tests/models/vit/test_modeling_vit.py",
761
+ "tests/models/vit/test_modeling_tf_vit.py",
762
+ "tests/models/vit/test_modeling_flax_vit.py",
763
+ }
764
+ self.assertEqual(test_files, vit_test_files)
765
+
766
+ doc_file = str(Path(all_vit_files["doc_file"]).relative_to(REPO_PATH))
767
+ self.assertEqual(doc_file, "docs/source/en/model_doc/vit.md")
768
+
769
+ self.assertEqual(all_vit_files["module_name"], "vit")
770
+
771
+ vit_model_patterns = vit_info["model_patterns"]
772
+ self.assertEqual(vit_model_patterns.model_name, "ViT")
773
+ self.assertEqual(vit_model_patterns.checkpoint, "google/vit-base-patch16-224-in21k")
774
+ self.assertEqual(vit_model_patterns.model_type, "vit")
775
+ self.assertEqual(vit_model_patterns.model_lower_cased, "vit")
776
+ self.assertEqual(vit_model_patterns.model_camel_cased, "ViT")
777
+ self.assertEqual(vit_model_patterns.model_upper_cased, "VIT")
778
+ self.assertEqual(vit_model_patterns.config_class, "ViTConfig")
779
+ self.assertEqual(vit_model_patterns.feature_extractor_class, "ViTFeatureExtractor")
780
+ self.assertEqual(vit_model_patterns.image_processor_class, "ViTImageProcessor")
781
+ self.assertIsNone(vit_model_patterns.tokenizer_class)
782
+ self.assertIsNone(vit_model_patterns.processor_class)
783
+
784
+ def test_retrieve_info_for_model_with_wav2vec2(self):
785
+ wav2vec2_info = retrieve_info_for_model("wav2vec2")
786
+ wav2vec2_classes = [
787
+ "Wav2Vec2Model",
788
+ "Wav2Vec2ForPreTraining",
789
+ "Wav2Vec2ForAudioFrameClassification",
790
+ "Wav2Vec2ForCTC",
791
+ "Wav2Vec2ForMaskedLM",
792
+ "Wav2Vec2ForSequenceClassification",
793
+ "Wav2Vec2ForXVector",
794
+ ]
795
+ expected_model_classes = {
796
+ "pt": set(wav2vec2_classes),
797
+ "tf": {f"TF{m}" for m in [wav2vec2_classes[0], wav2vec2_classes[-2]]},
798
+ "flax": {f"Flax{m}" for m in wav2vec2_classes[:2]},
799
+ }
800
+
801
+ self.assertEqual(set(wav2vec2_info["frameworks"]), {"pt", "tf", "flax"})
802
+ model_classes = {k: set(v) for k, v in wav2vec2_info["model_classes"].items()}
803
+ self.assertEqual(model_classes, expected_model_classes)
804
+
805
+ all_wav2vec2_files = wav2vec2_info["model_files"]
806
+ model_files = {str(Path(f).relative_to(REPO_PATH)) for f in all_wav2vec2_files["model_files"]}
807
+ self.assertEqual(model_files, WAV2VEC2_MODEL_FILES)
808
+
809
+ test_files = {str(Path(f).relative_to(REPO_PATH)) for f in all_wav2vec2_files["test_files"]}
810
+ wav2vec2_test_files = {
811
+ "tests/models/wav2vec2/test_feature_extraction_wav2vec2.py",
812
+ "tests/models/wav2vec2/test_modeling_wav2vec2.py",
813
+ "tests/models/wav2vec2/test_modeling_tf_wav2vec2.py",
814
+ "tests/models/wav2vec2/test_modeling_flax_wav2vec2.py",
815
+ "tests/models/wav2vec2/test_processor_wav2vec2.py",
816
+ "tests/models/wav2vec2/test_tokenization_wav2vec2.py",
817
+ }
818
+ self.assertEqual(test_files, wav2vec2_test_files)
819
+
820
+ doc_file = str(Path(all_wav2vec2_files["doc_file"]).relative_to(REPO_PATH))
821
+ self.assertEqual(doc_file, "docs/source/en/model_doc/wav2vec2.md")
822
+
823
+ self.assertEqual(all_wav2vec2_files["module_name"], "wav2vec2")
824
+
825
+ wav2vec2_model_patterns = wav2vec2_info["model_patterns"]
826
+ self.assertEqual(wav2vec2_model_patterns.model_name, "Wav2Vec2")
827
+ self.assertEqual(wav2vec2_model_patterns.checkpoint, "facebook/wav2vec2-base-960h")
828
+ self.assertEqual(wav2vec2_model_patterns.model_type, "wav2vec2")
829
+ self.assertEqual(wav2vec2_model_patterns.model_lower_cased, "wav2vec2")
830
+ self.assertEqual(wav2vec2_model_patterns.model_camel_cased, "Wav2Vec2")
831
+ self.assertEqual(wav2vec2_model_patterns.model_upper_cased, "WAV2VEC2")
832
+ self.assertEqual(wav2vec2_model_patterns.config_class, "Wav2Vec2Config")
833
+ self.assertEqual(wav2vec2_model_patterns.feature_extractor_class, "Wav2Vec2FeatureExtractor")
834
+ self.assertEqual(wav2vec2_model_patterns.processor_class, "Wav2Vec2Processor")
835
+ self.assertEqual(wav2vec2_model_patterns.tokenizer_class, "Wav2Vec2CTCTokenizer")
836
+
837
+ def test_clean_frameworks_in_init_with_gpt(self):
838
+ test_init = """
839
+ from typing import TYPE_CHECKING
840
+
841
+ from ...utils import _LazyModule, is_flax_available, is_tf_available, is_tokenizers_available, is_torch_available
842
+
843
+ _import_structure = {
844
+ "configuration_gpt2": ["GPT2Config", "GPT2OnnxConfig"],
845
+ "tokenization_gpt2": ["GPT2Tokenizer"],
846
+ }
847
+
848
+ try:
849
+ if not is_tokenizers_available():
850
+ raise OptionalDependencyNotAvailable()
851
+ except OptionalDependencyNotAvailable:
852
+ pass
853
+ else:
854
+ _import_structure["tokenization_gpt2_fast"] = ["GPT2TokenizerFast"]
855
+
856
+ try:
857
+ if not is_torch_available():
858
+ raise OptionalDependencyNotAvailable()
859
+ except OptionalDependencyNotAvailable:
860
+ pass
861
+ else:
862
+ _import_structure["modeling_gpt2"] = ["GPT2Model"]
863
+
864
+ try:
865
+ if not is_tf_available():
866
+ raise OptionalDependencyNotAvailable()
867
+ except OptionalDependencyNotAvailable:
868
+ pass
869
+ else:
870
+ _import_structure["modeling_tf_gpt2"] = ["TFGPT2Model"]
871
+
872
+ try:
873
+ if not is_flax_available():
874
+ raise OptionalDependencyNotAvailable()
875
+ except OptionalDependencyNotAvailable:
876
+ pass
877
+ else:
878
+ _import_structure["modeling_flax_gpt2"] = ["FlaxGPT2Model"]
879
+
880
+ if TYPE_CHECKING:
881
+ from .configuration_gpt2 import GPT2Config, GPT2OnnxConfig
882
+ from .tokenization_gpt2 import GPT2Tokenizer
883
+
884
+ try:
885
+ if not is_tokenizers_available():
886
+ raise OptionalDependencyNotAvailable()
887
+ except OptionalDependencyNotAvailable:
888
+ pass
889
+ else:
890
+ from .tokenization_gpt2_fast import GPT2TokenizerFast
891
+
892
+ try:
893
+ if not is_torch_available():
894
+ raise OptionalDependencyNotAvailable()
895
+ except OptionalDependencyNotAvailable:
896
+ pass
897
+ else:
898
+ from .modeling_gpt2 import GPT2Model
899
+
900
+ try:
901
+ if not is_tf_available():
902
+ raise OptionalDependencyNotAvailable()
903
+ except OptionalDependencyNotAvailable:
904
+ pass
905
+ else:
906
+ from .modeling_tf_gpt2 import TFGPT2Model
907
+
908
+ try:
909
+ if not is_flax_available():
910
+ raise OptionalDependencyNotAvailable()
911
+ except OptionalDependencyNotAvailable:
912
+ pass
913
+ else:
914
+ from .modeling_flax_gpt2 import FlaxGPT2Model
915
+
916
+ else:
917
+ import sys
918
+
919
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
920
+ """
921
+
922
+ init_no_tokenizer = """
923
+ from typing import TYPE_CHECKING
924
+
925
+ from ...utils import _LazyModule, is_flax_available, is_tf_available, is_torch_available
926
+
927
+ _import_structure = {
928
+ "configuration_gpt2": ["GPT2Config", "GPT2OnnxConfig"],
929
+ }
930
+
931
+ try:
932
+ if not is_torch_available():
933
+ raise OptionalDependencyNotAvailable()
934
+ except OptionalDependencyNotAvailable:
935
+ pass
936
+ else:
937
+ _import_structure["modeling_gpt2"] = ["GPT2Model"]
938
+
939
+ try:
940
+ if not is_tf_available():
941
+ raise OptionalDependencyNotAvailable()
942
+ except OptionalDependencyNotAvailable:
943
+ pass
944
+ else:
945
+ _import_structure["modeling_tf_gpt2"] = ["TFGPT2Model"]
946
+
947
+ try:
948
+ if not is_flax_available():
949
+ raise OptionalDependencyNotAvailable()
950
+ except OptionalDependencyNotAvailable:
951
+ pass
952
+ else:
953
+ _import_structure["modeling_flax_gpt2"] = ["FlaxGPT2Model"]
954
+
955
+ if TYPE_CHECKING:
956
+ from .configuration_gpt2 import GPT2Config, GPT2OnnxConfig
957
+
958
+ try:
959
+ if not is_torch_available():
960
+ raise OptionalDependencyNotAvailable()
961
+ except OptionalDependencyNotAvailable:
962
+ pass
963
+ else:
964
+ from .modeling_gpt2 import GPT2Model
965
+
966
+ try:
967
+ if not is_tf_available():
968
+ raise OptionalDependencyNotAvailable()
969
+ except OptionalDependencyNotAvailable:
970
+ pass
971
+ else:
972
+ from .modeling_tf_gpt2 import TFGPT2Model
973
+
974
+ try:
975
+ if not is_flax_available():
976
+ raise OptionalDependencyNotAvailable()
977
+ except OptionalDependencyNotAvailable:
978
+ pass
979
+ else:
980
+ from .modeling_flax_gpt2 import FlaxGPT2Model
981
+
982
+ else:
983
+ import sys
984
+
985
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
986
+ """
987
+
988
+ init_pt_only = """
989
+ from typing import TYPE_CHECKING
990
+
991
+ from ...utils import _LazyModule, is_tokenizers_available, is_torch_available
992
+
993
+ _import_structure = {
994
+ "configuration_gpt2": ["GPT2Config", "GPT2OnnxConfig"],
995
+ "tokenization_gpt2": ["GPT2Tokenizer"],
996
+ }
997
+
998
+ try:
999
+ if not is_tokenizers_available():
1000
+ raise OptionalDependencyNotAvailable()
1001
+ except OptionalDependencyNotAvailable:
1002
+ pass
1003
+ else:
1004
+ _import_structure["tokenization_gpt2_fast"] = ["GPT2TokenizerFast"]
1005
+
1006
+ try:
1007
+ if not is_torch_available():
1008
+ raise OptionalDependencyNotAvailable()
1009
+ except OptionalDependencyNotAvailable:
1010
+ pass
1011
+ else:
1012
+ _import_structure["modeling_gpt2"] = ["GPT2Model"]
1013
+
1014
+ if TYPE_CHECKING:
1015
+ from .configuration_gpt2 import GPT2Config, GPT2OnnxConfig
1016
+ from .tokenization_gpt2 import GPT2Tokenizer
1017
+
1018
+ try:
1019
+ if not is_tokenizers_available():
1020
+ raise OptionalDependencyNotAvailable()
1021
+ except OptionalDependencyNotAvailable:
1022
+ pass
1023
+ else:
1024
+ from .tokenization_gpt2_fast import GPT2TokenizerFast
1025
+
1026
+ try:
1027
+ if not is_torch_available():
1028
+ raise OptionalDependencyNotAvailable()
1029
+ except OptionalDependencyNotAvailable:
1030
+ pass
1031
+ else:
1032
+ from .modeling_gpt2 import GPT2Model
1033
+
1034
+ else:
1035
+ import sys
1036
+
1037
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
1038
+ """
1039
+
1040
+ init_pt_only_no_tokenizer = """
1041
+ from typing import TYPE_CHECKING
1042
+
1043
+ from ...utils import _LazyModule, is_torch_available
1044
+
1045
+ _import_structure = {
1046
+ "configuration_gpt2": ["GPT2Config", "GPT2OnnxConfig"],
1047
+ }
1048
+
1049
+ try:
1050
+ if not is_torch_available():
1051
+ raise OptionalDependencyNotAvailable()
1052
+ except OptionalDependencyNotAvailable:
1053
+ pass
1054
+ else:
1055
+ _import_structure["modeling_gpt2"] = ["GPT2Model"]
1056
+
1057
+ if TYPE_CHECKING:
1058
+ from .configuration_gpt2 import GPT2Config, GPT2OnnxConfig
1059
+
1060
+ try:
1061
+ if not is_torch_available():
1062
+ raise OptionalDependencyNotAvailable()
1063
+ except OptionalDependencyNotAvailable:
1064
+ pass
1065
+ else:
1066
+ from .modeling_gpt2 import GPT2Model
1067
+
1068
+ else:
1069
+ import sys
1070
+
1071
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
1072
+ """
1073
+
1074
+ with tempfile.TemporaryDirectory() as tmp_dir:
1075
+ file_name = os.path.join(tmp_dir, "../__init__.py")
1076
+
1077
+ self.init_file(file_name, test_init)
1078
+ clean_frameworks_in_init(file_name, keep_processing=False)
1079
+ self.check_result(file_name, init_no_tokenizer)
1080
+
1081
+ self.init_file(file_name, test_init)
1082
+ clean_frameworks_in_init(file_name, frameworks=["pt"])
1083
+ self.check_result(file_name, init_pt_only)
1084
+
1085
+ self.init_file(file_name, test_init)
1086
+ clean_frameworks_in_init(file_name, frameworks=["pt"], keep_processing=False)
1087
+ self.check_result(file_name, init_pt_only_no_tokenizer)
1088
+
1089
+ def test_clean_frameworks_in_init_with_vit(self):
1090
+ test_init = """
1091
+ from typing import TYPE_CHECKING
1092
+
1093
+ from ...utils import _LazyModule, is_flax_available, is_tf_available, is_torch_available, is_vision_available
1094
+
1095
+ _import_structure = {
1096
+ "configuration_vit": ["ViTConfig"],
1097
+ }
1098
+
1099
+ try:
1100
+ if not is_vision_available():
1101
+ raise OptionalDependencyNotAvailable()
1102
+ except OptionalDependencyNotAvailable:
1103
+ pass
1104
+ else:
1105
+ _import_structure["image_processing_vit"] = ["ViTImageProcessor"]
1106
+
1107
+ try:
1108
+ if not is_torch_available():
1109
+ raise OptionalDependencyNotAvailable()
1110
+ except OptionalDependencyNotAvailable:
1111
+ pass
1112
+ else:
1113
+ _import_structure["modeling_vit"] = ["ViTModel"]
1114
+
1115
+ try:
1116
+ if not is_tf_available():
1117
+ raise OptionalDependencyNotAvailable()
1118
+ except OptionalDependencyNotAvailable:
1119
+ pass
1120
+ else:
1121
+ _import_structure["modeling_tf_vit"] = ["TFViTModel"]
1122
+
1123
+ try:
1124
+ if not is_flax_available():
1125
+ raise OptionalDependencyNotAvailable()
1126
+ except OptionalDependencyNotAvailable:
1127
+ pass
1128
+ else:
1129
+ _import_structure["modeling_flax_vit"] = ["FlaxViTModel"]
1130
+
1131
+ if TYPE_CHECKING:
1132
+ from .configuration_vit import ViTConfig
1133
+
1134
+ try:
1135
+ if not is_vision_available():
1136
+ raise OptionalDependencyNotAvailable()
1137
+ except OptionalDependencyNotAvailable:
1138
+ pass
1139
+ else:
1140
+ from .image_processing_vit import ViTImageProcessor
1141
+
1142
+ try:
1143
+ if not is_torch_available():
1144
+ raise OptionalDependencyNotAvailable()
1145
+ except OptionalDependencyNotAvailable:
1146
+ pass
1147
+ else:
1148
+ from .modeling_vit import ViTModel
1149
+
1150
+ try:
1151
+ if not is_tf_available():
1152
+ raise OptionalDependencyNotAvailable()
1153
+ except OptionalDependencyNotAvailable:
1154
+ pass
1155
+ else:
1156
+ from .modeling_tf_vit import TFViTModel
1157
+
1158
+ try:
1159
+ if not is_flax_available():
1160
+ raise OptionalDependencyNotAvailable()
1161
+ except OptionalDependencyNotAvailable:
1162
+ pass
1163
+ else:
1164
+ from .modeling_flax_vit import FlaxViTModel
1165
+
1166
+ else:
1167
+ import sys
1168
+
1169
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
1170
+ """
1171
+
1172
+ init_no_feature_extractor = """
1173
+ from typing import TYPE_CHECKING
1174
+
1175
+ from ...utils import _LazyModule, is_flax_available, is_tf_available, is_torch_available
1176
+
1177
+ _import_structure = {
1178
+ "configuration_vit": ["ViTConfig"],
1179
+ }
1180
+
1181
+ try:
1182
+ if not is_torch_available():
1183
+ raise OptionalDependencyNotAvailable()
1184
+ except OptionalDependencyNotAvailable:
1185
+ pass
1186
+ else:
1187
+ _import_structure["modeling_vit"] = ["ViTModel"]
1188
+
1189
+ try:
1190
+ if not is_tf_available():
1191
+ raise OptionalDependencyNotAvailable()
1192
+ except OptionalDependencyNotAvailable:
1193
+ pass
1194
+ else:
1195
+ _import_structure["modeling_tf_vit"] = ["TFViTModel"]
1196
+
1197
+ try:
1198
+ if not is_flax_available():
1199
+ raise OptionalDependencyNotAvailable()
1200
+ except OptionalDependencyNotAvailable:
1201
+ pass
1202
+ else:
1203
+ _import_structure["modeling_flax_vit"] = ["FlaxViTModel"]
1204
+
1205
+ if TYPE_CHECKING:
1206
+ from .configuration_vit import ViTConfig
1207
+
1208
+ try:
1209
+ if not is_torch_available():
1210
+ raise OptionalDependencyNotAvailable()
1211
+ except OptionalDependencyNotAvailable:
1212
+ pass
1213
+ else:
1214
+ from .modeling_vit import ViTModel
1215
+
1216
+ try:
1217
+ if not is_tf_available():
1218
+ raise OptionalDependencyNotAvailable()
1219
+ except OptionalDependencyNotAvailable:
1220
+ pass
1221
+ else:
1222
+ from .modeling_tf_vit import TFViTModel
1223
+
1224
+ try:
1225
+ if not is_flax_available():
1226
+ raise OptionalDependencyNotAvailable()
1227
+ except OptionalDependencyNotAvailable:
1228
+ pass
1229
+ else:
1230
+ from .modeling_flax_vit import FlaxViTModel
1231
+
1232
+ else:
1233
+ import sys
1234
+
1235
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
1236
+ """
1237
+
1238
+ init_pt_only = """
1239
+ from typing import TYPE_CHECKING
1240
+
1241
+ from ...utils import _LazyModule, is_torch_available, is_vision_available
1242
+
1243
+ _import_structure = {
1244
+ "configuration_vit": ["ViTConfig"],
1245
+ }
1246
+
1247
+ try:
1248
+ if not is_vision_available():
1249
+ raise OptionalDependencyNotAvailable()
1250
+ except OptionalDependencyNotAvailable:
1251
+ pass
1252
+ else:
1253
+ _import_structure["image_processing_vit"] = ["ViTImageProcessor"]
1254
+
1255
+ try:
1256
+ if not is_torch_available():
1257
+ raise OptionalDependencyNotAvailable()
1258
+ except OptionalDependencyNotAvailable:
1259
+ pass
1260
+ else:
1261
+ _import_structure["modeling_vit"] = ["ViTModel"]
1262
+
1263
+ if TYPE_CHECKING:
1264
+ from .configuration_vit import ViTConfig
1265
+
1266
+ try:
1267
+ if not is_vision_available():
1268
+ raise OptionalDependencyNotAvailable()
1269
+ except OptionalDependencyNotAvailable:
1270
+ pass
1271
+ else:
1272
+ from .image_processing_vit import ViTImageProcessor
1273
+
1274
+ try:
1275
+ if not is_torch_available():
1276
+ raise OptionalDependencyNotAvailable()
1277
+ except OptionalDependencyNotAvailable:
1278
+ pass
1279
+ else:
1280
+ from .modeling_vit import ViTModel
1281
+
1282
+ else:
1283
+ import sys
1284
+
1285
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
1286
+ """
1287
+
1288
+ init_pt_only_no_feature_extractor = """
1289
+ from typing import TYPE_CHECKING
1290
+
1291
+ from ...utils import _LazyModule, is_torch_available
1292
+
1293
+ _import_structure = {
1294
+ "configuration_vit": ["ViTConfig"],
1295
+ }
1296
+
1297
+ try:
1298
+ if not is_torch_available():
1299
+ raise OptionalDependencyNotAvailable()
1300
+ except OptionalDependencyNotAvailable:
1301
+ pass
1302
+ else:
1303
+ _import_structure["modeling_vit"] = ["ViTModel"]
1304
+
1305
+ if TYPE_CHECKING:
1306
+ from .configuration_vit import ViTConfig
1307
+
1308
+ try:
1309
+ if not is_torch_available():
1310
+ raise OptionalDependencyNotAvailable()
1311
+ except OptionalDependencyNotAvailable:
1312
+ pass
1313
+ else:
1314
+ from .modeling_vit import ViTModel
1315
+
1316
+ else:
1317
+ import sys
1318
+
1319
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
1320
+ """
1321
+
1322
+ with tempfile.TemporaryDirectory() as tmp_dir:
1323
+ file_name = os.path.join(tmp_dir, "../__init__.py")
1324
+
1325
+ self.init_file(file_name, test_init)
1326
+ clean_frameworks_in_init(file_name, keep_processing=False)
1327
+ self.check_result(file_name, init_no_feature_extractor)
1328
+
1329
+ self.init_file(file_name, test_init)
1330
+ clean_frameworks_in_init(file_name, frameworks=["pt"])
1331
+ self.check_result(file_name, init_pt_only)
1332
+
1333
+ self.init_file(file_name, test_init)
1334
+ clean_frameworks_in_init(file_name, frameworks=["pt"], keep_processing=False)
1335
+ self.check_result(file_name, init_pt_only_no_feature_extractor)
1336
+
1337
+ def test_duplicate_doc_file(self):
1338
+ test_doc = """
1339
+ # GPT2
1340
+
1341
+ ## Overview
1342
+
1343
+ Overview of the model.
1344
+
1345
+ ## GPT2Config
1346
+
1347
+ [[autodoc]] GPT2Config
1348
+
1349
+ ## GPT2Tokenizer
1350
+
1351
+ [[autodoc]] GPT2Tokenizer
1352
+ - save_vocabulary
1353
+
1354
+ ## GPT2TokenizerFast
1355
+
1356
+ [[autodoc]] GPT2TokenizerFast
1357
+
1358
+ ## GPT2 specific outputs
1359
+
1360
+ [[autodoc]] models.gpt2.modeling_gpt2.GPT2DoubleHeadsModelOutput
1361
+
1362
+ [[autodoc]] models.gpt2.modeling_tf_gpt2.TFGPT2DoubleHeadsModelOutput
1363
+
1364
+ ## GPT2Model
1365
+
1366
+ [[autodoc]] GPT2Model
1367
+ - forward
1368
+
1369
+ ## TFGPT2Model
1370
+
1371
+ [[autodoc]] TFGPT2Model
1372
+ - call
1373
+
1374
+ ## FlaxGPT2Model
1375
+
1376
+ [[autodoc]] FlaxGPT2Model
1377
+ - __call__
1378
+
1379
+ """
1380
+ test_new_doc = """
1381
+ # GPT-New New
1382
+
1383
+ ## Overview
1384
+
1385
+ The GPT-New New model was proposed in [<INSERT PAPER NAME HERE>](<INSERT PAPER LINK HERE>) by <INSERT AUTHORS HERE>.
1386
+ <INSERT SHORT SUMMARY HERE>
1387
+
1388
+ The abstract from the paper is the following:
1389
+
1390
+ *<INSERT PAPER ABSTRACT HERE>*
1391
+
1392
+ Tips:
1393
+
1394
+ <INSERT TIPS ABOUT MODEL HERE>
1395
+
1396
+ This model was contributed by [INSERT YOUR HF USERNAME HERE](https://huggingface.co/<INSERT YOUR HF USERNAME HERE>).
1397
+ The original code can be found [here](<INSERT LINK TO GITHUB REPO HERE>).
1398
+
1399
+
1400
+ ## GPTNewNewConfig
1401
+
1402
+ [[autodoc]] GPTNewNewConfig
1403
+
1404
+ ## GPTNewNewTokenizer
1405
+
1406
+ [[autodoc]] GPTNewNewTokenizer
1407
+ - save_vocabulary
1408
+
1409
+ ## GPTNewNewTokenizerFast
1410
+
1411
+ [[autodoc]] GPTNewNewTokenizerFast
1412
+
1413
+ ## GPTNewNew specific outputs
1414
+
1415
+ [[autodoc]] models.gpt_new_new.modeling_gpt_new_new.GPTNewNewDoubleHeadsModelOutput
1416
+
1417
+ [[autodoc]] models.gpt_new_new.modeling_tf_gpt_new_new.TFGPTNewNewDoubleHeadsModelOutput
1418
+
1419
+ ## GPTNewNewModel
1420
+
1421
+ [[autodoc]] GPTNewNewModel
1422
+ - forward
1423
+
1424
+ ## TFGPTNewNewModel
1425
+
1426
+ [[autodoc]] TFGPTNewNewModel
1427
+ - call
1428
+
1429
+ ## FlaxGPTNewNewModel
1430
+
1431
+ [[autodoc]] FlaxGPTNewNewModel
1432
+ - __call__
1433
+
1434
+ """
1435
+
1436
+ with tempfile.TemporaryDirectory() as tmp_dir:
1437
+ doc_file = os.path.join(tmp_dir, "gpt2.md")
1438
+ new_doc_file = os.path.join(tmp_dir, "gpt-new-new.md")
1439
+
1440
+ gpt2_model_patterns = ModelPatterns("GPT2", "gpt2", tokenizer_class="GPT2Tokenizer")
1441
+ new_model_patterns = ModelPatterns(
1442
+ "GPT-New New", "huggingface/gpt-new-new", tokenizer_class="GPTNewNewTokenizer"
1443
+ )
1444
+
1445
+ self.init_file(doc_file, test_doc)
1446
+ duplicate_doc_file(doc_file, gpt2_model_patterns, new_model_patterns)
1447
+ self.check_result(new_doc_file, test_new_doc)
1448
+
1449
+ test_new_doc_pt_only = test_new_doc.replace(
1450
+ """
1451
+ ## TFGPTNewNewModel
1452
+
1453
+ [[autodoc]] TFGPTNewNewModel
1454
+ - call
1455
+
1456
+ ## FlaxGPTNewNewModel
1457
+
1458
+ [[autodoc]] FlaxGPTNewNewModel
1459
+ - __call__
1460
+
1461
+ """,
1462
+ "",
1463
+ )
1464
+ self.init_file(doc_file, test_doc)
1465
+ duplicate_doc_file(doc_file, gpt2_model_patterns, new_model_patterns, frameworks=["pt"])
1466
+ self.check_result(new_doc_file, test_new_doc_pt_only)
1467
+
1468
+ test_new_doc_no_tok = test_new_doc.replace(
1469
+ """
1470
+ ## GPTNewNewTokenizer
1471
+
1472
+ [[autodoc]] GPTNewNewTokenizer
1473
+ - save_vocabulary
1474
+
1475
+ ## GPTNewNewTokenizerFast
1476
+
1477
+ [[autodoc]] GPTNewNewTokenizerFast
1478
+ """,
1479
+ "",
1480
+ )
1481
+ new_model_patterns = ModelPatterns(
1482
+ "GPT-New New", "huggingface/gpt-new-new", tokenizer_class="GPT2Tokenizer"
1483
+ )
1484
+ self.init_file(doc_file, test_doc)
1485
+ duplicate_doc_file(doc_file, gpt2_model_patterns, new_model_patterns)
1486
+ print(test_new_doc_no_tok)
1487
+ self.check_result(new_doc_file, test_new_doc_no_tok)
1488
+
1489
+ test_new_doc_pt_only_no_tok = test_new_doc_no_tok.replace(
1490
+ """
1491
+ ## TFGPTNewNewModel
1492
+
1493
+ [[autodoc]] TFGPTNewNewModel
1494
+ - call
1495
+
1496
+ ## FlaxGPTNewNewModel
1497
+
1498
+ [[autodoc]] FlaxGPTNewNewModel
1499
+ - __call__
1500
+
1501
+ """,
1502
+ "",
1503
+ )
1504
+ self.init_file(doc_file, test_doc)
1505
+ duplicate_doc_file(doc_file, gpt2_model_patterns, new_model_patterns, frameworks=["pt"])
1506
+ self.check_result(new_doc_file, test_new_doc_pt_only_no_tok)
docs/transformers/tests/utils/test_audio_utils.py ADDED
@@ -0,0 +1,1751 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 HuggingFace Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import unittest
16
+
17
+ import numpy as np
18
+ import pytest
19
+
20
+ from transformers.audio_utils import (
21
+ amplitude_to_db,
22
+ amplitude_to_db_batch,
23
+ chroma_filter_bank,
24
+ hertz_to_mel,
25
+ mel_filter_bank,
26
+ mel_to_hertz,
27
+ power_to_db,
28
+ power_to_db_batch,
29
+ spectrogram,
30
+ spectrogram_batch,
31
+ window_function,
32
+ )
33
+ from transformers.testing_utils import is_librosa_available, require_librosa
34
+
35
+
36
+ if is_librosa_available():
37
+ from librosa.filters import chroma
38
+
39
+
40
+ class AudioUtilsFunctionTester(unittest.TestCase):
41
+ # will be set in `def _load_datasamples`
42
+ _dataset = None
43
+
44
+ def test_hertz_to_mel(self):
45
+ self.assertEqual(hertz_to_mel(0.0), 0.0)
46
+ self.assertAlmostEqual(hertz_to_mel(100), 150.48910241)
47
+
48
+ inputs = np.array([100, 200])
49
+ expected = np.array([150.48910241, 283.22989816])
50
+ self.assertTrue(np.allclose(hertz_to_mel(inputs), expected))
51
+
52
+ self.assertEqual(hertz_to_mel(0.0, "slaney"), 0.0)
53
+ self.assertEqual(hertz_to_mel(100, "slaney"), 1.5)
54
+
55
+ inputs = np.array([60, 100, 200, 1000, 1001, 2000])
56
+ expected = np.array([0.9, 1.5, 3.0, 15.0, 15.01453781, 25.08188016])
57
+ self.assertTrue(np.allclose(hertz_to_mel(inputs, "slaney"), expected))
58
+
59
+ inputs = np.array([60, 100, 200, 1000, 1001, 2000])
60
+ expected = np.array([92.6824, 150.4899, 283.2313, 999.9907, 1000.6534, 1521.3674])
61
+ self.assertTrue(np.allclose(hertz_to_mel(inputs, "kaldi"), expected))
62
+
63
+ with pytest.raises(ValueError):
64
+ hertz_to_mel(100, mel_scale=None)
65
+
66
+ def test_mel_to_hertz(self):
67
+ self.assertEqual(mel_to_hertz(0.0), 0.0)
68
+ self.assertAlmostEqual(mel_to_hertz(150.48910241), 100)
69
+
70
+ inputs = np.array([150.48910241, 283.22989816])
71
+ expected = np.array([100, 200])
72
+ self.assertTrue(np.allclose(mel_to_hertz(inputs), expected))
73
+
74
+ self.assertEqual(mel_to_hertz(0.0, "slaney"), 0.0)
75
+ self.assertEqual(mel_to_hertz(1.5, "slaney"), 100)
76
+
77
+ inputs = np.array([0.9, 1.5, 3.0, 15.0, 15.01453781, 25.08188016])
78
+ expected = np.array([60, 100, 200, 1000, 1001, 2000])
79
+ self.assertTrue(np.allclose(mel_to_hertz(inputs, "slaney"), expected))
80
+
81
+ inputs = np.array([92.6824, 150.4899, 283.2313, 999.9907, 1000.6534, 1521.3674])
82
+ expected = np.array([60, 100, 200, 1000, 1001, 2000])
83
+ self.assertTrue(np.allclose(mel_to_hertz(inputs, "kaldi"), expected))
84
+
85
+ with pytest.raises(ValueError):
86
+ mel_to_hertz(100, mel_scale=None)
87
+
88
+ def test_mel_filter_bank_shape(self):
89
+ mel_filters = mel_filter_bank(
90
+ num_frequency_bins=513,
91
+ num_mel_filters=13,
92
+ min_frequency=100,
93
+ max_frequency=4000,
94
+ sampling_rate=16000,
95
+ norm=None,
96
+ mel_scale="htk",
97
+ )
98
+ self.assertEqual(mel_filters.shape, (513, 13))
99
+
100
+ mel_filters = mel_filter_bank(
101
+ num_frequency_bins=513,
102
+ num_mel_filters=13,
103
+ min_frequency=100,
104
+ max_frequency=4000,
105
+ sampling_rate=16000,
106
+ norm="slaney",
107
+ mel_scale="slaney",
108
+ )
109
+ self.assertEqual(mel_filters.shape, (513, 13))
110
+
111
+ mel_filters = mel_filter_bank(
112
+ num_frequency_bins=513,
113
+ num_mel_filters=13,
114
+ min_frequency=100,
115
+ max_frequency=4000,
116
+ sampling_rate=16000,
117
+ norm="slaney",
118
+ mel_scale="slaney",
119
+ triangularize_in_mel_space=True,
120
+ )
121
+ self.assertEqual(mel_filters.shape, (513, 13))
122
+
123
+ def test_mel_filter_bank_htk(self):
124
+ mel_filters = mel_filter_bank(
125
+ num_frequency_bins=16,
126
+ num_mel_filters=4,
127
+ min_frequency=0,
128
+ max_frequency=2000,
129
+ sampling_rate=4000,
130
+ norm=None,
131
+ mel_scale="htk",
132
+ )
133
+ # fmt: off
134
+ expected = np.array([
135
+ [0.0 , 0.0 , 0.0 , 0.0 ],
136
+ [0.61454786, 0.0 , 0.0 , 0.0 ],
137
+ [0.82511046, 0.17488954, 0.0 , 0.0 ],
138
+ [0.35597035, 0.64402965, 0.0 , 0.0 ],
139
+ [0.0 , 0.91360726, 0.08639274, 0.0 ],
140
+ [0.0 , 0.55547007, 0.44452993, 0.0 ],
141
+ [0.0 , 0.19733289, 0.80266711, 0.0 ],
142
+ [0.0 , 0.0 , 0.87724349, 0.12275651],
143
+ [0.0 , 0.0 , 0.6038449 , 0.3961551 ],
144
+ [0.0 , 0.0 , 0.33044631, 0.66955369],
145
+ [0.0 , 0.0 , 0.05704771, 0.94295229],
146
+ [0.0 , 0.0 , 0.0 , 0.83483975],
147
+ [0.0 , 0.0 , 0.0 , 0.62612982],
148
+ [0.0 , 0.0 , 0.0 , 0.41741988],
149
+ [0.0 , 0.0 , 0.0 , 0.20870994],
150
+ [0.0 , 0.0 , 0.0 , 0.0 ]
151
+ ])
152
+ # fmt: on
153
+ self.assertTrue(np.allclose(mel_filters, expected))
154
+
155
+ def test_mel_filter_bank_slaney(self):
156
+ mel_filters = mel_filter_bank(
157
+ num_frequency_bins=16,
158
+ num_mel_filters=4,
159
+ min_frequency=0,
160
+ max_frequency=2000,
161
+ sampling_rate=4000,
162
+ norm=None,
163
+ mel_scale="slaney",
164
+ )
165
+ # fmt: off
166
+ expected = np.array([
167
+ [0.0 , 0.0 , 0.0 , 0.0 ],
168
+ [0.39869419, 0.0 , 0.0 , 0.0 ],
169
+ [0.79738839, 0.0 , 0.0 , 0.0 ],
170
+ [0.80391742, 0.19608258, 0.0 , 0.0 ],
171
+ [0.40522322, 0.59477678, 0.0 , 0.0 ],
172
+ [0.00652903, 0.99347097, 0.0 , 0.0 ],
173
+ [0.0 , 0.60796161, 0.39203839, 0.0 ],
174
+ [0.0 , 0.20939631, 0.79060369, 0.0 ],
175
+ [0.0 , 0.0 , 0.84685344, 0.15314656],
176
+ [0.0 , 0.0 , 0.52418477, 0.47581523],
177
+ [0.0 , 0.0 , 0.2015161 , 0.7984839 ],
178
+ [0.0 , 0.0 , 0.0 , 0.9141874 ],
179
+ [0.0 , 0.0 , 0.0 , 0.68564055],
180
+ [0.0 , 0.0 , 0.0 , 0.4570937 ],
181
+ [0.0 , 0.0 , 0.0 , 0.22854685],
182
+ [0.0 , 0.0 , 0.0 , 0.0 ]
183
+ ])
184
+ # fmt: on
185
+ self.assertTrue(np.allclose(mel_filters, expected))
186
+
187
+ def test_mel_filter_bank_kaldi(self):
188
+ mel_filters = mel_filter_bank(
189
+ num_frequency_bins=16,
190
+ num_mel_filters=4,
191
+ min_frequency=0,
192
+ max_frequency=2000,
193
+ sampling_rate=4000,
194
+ norm=None,
195
+ mel_scale="kaldi",
196
+ triangularize_in_mel_space=True,
197
+ )
198
+ # fmt: off
199
+ # here the expected values from torchaudio.compliance.kaldi.get_mel_banks
200
+ # note that we compute values in float64 while they do it in float32
201
+ expected = np.array(
202
+ [
203
+ [0.0000000000000000, 0.0000000000000000, 0.0000000000000000, 0.0000000000000000],
204
+ [0.6457883715629578, 0.0000000000000000, 0.0000000000000000, 0.0000000000000000],
205
+ [0.8044781088829041, 0.1955219060182571, 0.0000000000000000, 0.0000000000000000],
206
+ [0.3258901536464691, 0.6741098165512085, 0.0000000000000000, 0.0000000000000000],
207
+ [0.0000000000000000, 0.9021250009536743, 0.0978749766945839, 0.0000000000000000],
208
+ [0.0000000000000000, 0.5219038724899292, 0.4780961275100708, 0.0000000000000000],
209
+ [0.0000000000000000, 0.1771058291196823, 0.8228941559791565, 0.0000000000000000],
210
+ [0.0000000000000000, 0.0000000000000000, 0.8616894483566284, 0.1383105516433716],
211
+ [0.0000000000000000, 0.0000000000000000, 0.5710380673408508, 0.4289619624614716],
212
+ [0.0000000000000000, 0.0000000000000000, 0.3015440106391907, 0.6984559893608093],
213
+ [0.0000000000000000, 0.0000000000000000, 0.0503356307744980, 0.9496643543243408],
214
+ [0.0000000000000000, 0.0000000000000000, 0.0000000000000000, 0.8150880336761475],
215
+ [0.0000000000000000, 0.0000000000000000, 0.0000000000000000, 0.5938932299613953],
216
+ [0.0000000000000000, 0.0000000000000000, 0.0000000000000000, 0.3851676583290100],
217
+ [0.0000000000000000, 0.0000000000000000, 0.0000000000000000, 0.1875794380903244],
218
+ ],
219
+ dtype=np.float64,
220
+ )
221
+ # fmt: on
222
+
223
+ # kaldi implementation does not compute values for last fft bin
224
+ # indeed, they enforce max_frequency <= sampling_rate / 2 and
225
+ # therefore they know that last fft bin filter bank values will be all 0
226
+ # and pad after with zeros
227
+ # to comply with our API for `mel_filter_bank`, we need to also pad here
228
+ expected = np.pad(expected, ((0, 1), (0, 0)))
229
+
230
+ self.assertTrue(np.allclose(mel_filters, expected))
231
+
232
+ def test_mel_filter_bank_slaney_norm(self):
233
+ mel_filters = mel_filter_bank(
234
+ num_frequency_bins=16,
235
+ num_mel_filters=4,
236
+ min_frequency=0,
237
+ max_frequency=2000,
238
+ sampling_rate=4000,
239
+ norm="slaney",
240
+ mel_scale="slaney",
241
+ )
242
+ # fmt: off
243
+ expected = np.array([
244
+ [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
245
+ [1.19217795e-03, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
246
+ [2.38435591e-03, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
247
+ [2.40387905e-03, 5.86232616e-04, 0.00000000e+00, 0.00000000e+00],
248
+ [1.21170110e-03, 1.77821783e-03, 0.00000000e+00, 0.00000000e+00],
249
+ [1.95231437e-05, 2.97020305e-03, 0.00000000e+00, 0.00000000e+00],
250
+ [0.00000000e+00, 1.81763684e-03, 1.04857612e-03, 0.00000000e+00],
251
+ [0.00000000e+00, 6.26036972e-04, 2.11460963e-03, 0.00000000e+00],
252
+ [0.00000000e+00, 0.00000000e+00, 2.26505954e-03, 3.07332945e-04],
253
+ [0.00000000e+00, 0.00000000e+00, 1.40202503e-03, 9.54861093e-04],
254
+ [0.00000000e+00, 0.00000000e+00, 5.38990521e-04, 1.60238924e-03],
255
+ [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 1.83458185e-03],
256
+ [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 1.37593638e-03],
257
+ [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 9.17290923e-04],
258
+ [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 4.58645462e-04],
259
+ [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00]
260
+ ])
261
+ # fmt: on
262
+ self.assertTrue(np.allclose(mel_filters, expected))
263
+
264
+ def test_window_function(self):
265
+ window = window_function(16, "hann")
266
+ self.assertEqual(len(window), 16)
267
+
268
+ # fmt: off
269
+ expected = np.array([
270
+ 0.0, 0.03806023, 0.14644661, 0.30865828, 0.5, 0.69134172, 0.85355339, 0.96193977,
271
+ 1.0, 0.96193977, 0.85355339, 0.69134172, 0.5, 0.30865828, 0.14644661, 0.03806023,
272
+ ])
273
+ # fmt: on
274
+ self.assertTrue(np.allclose(window, expected))
275
+
276
+ def _load_datasamples(self, num_samples):
277
+ from datasets import load_dataset
278
+
279
+ if self._dataset is None:
280
+ self._dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
281
+ speech_samples = self._dataset.sort("id").select(range(num_samples))[:num_samples]["audio"]
282
+ return [x["array"] for x in speech_samples]
283
+
284
+ def test_spectrogram_impulse(self):
285
+ waveform = np.zeros(40)
286
+ waveform[9] = 1.0 # impulse shifted in time
287
+
288
+ spec = spectrogram(
289
+ waveform,
290
+ window_function(12, "hann", frame_length=16),
291
+ frame_length=16,
292
+ hop_length=4,
293
+ power=1.0,
294
+ center=True,
295
+ pad_mode="reflect",
296
+ onesided=True,
297
+ )
298
+ self.assertEqual(spec.shape, (9, 11))
299
+
300
+ expected = np.array([[0.0, 0.0669873, 0.9330127, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]])
301
+ self.assertTrue(np.allclose(spec, expected))
302
+
303
+ def test_spectrogram_batch_impulse(self):
304
+ waveform1 = np.zeros(40)
305
+ waveform1[9] = 1.0
306
+
307
+ waveform2 = np.zeros(28)
308
+ waveform2[12] = 3.0
309
+
310
+ waveform3 = np.zeros(51)
311
+ waveform3[26] = 4.5
312
+
313
+ waveform_list = [waveform1, waveform2, waveform3]
314
+
315
+ spec_list = spectrogram_batch(
316
+ waveform_list,
317
+ window_function(12, "hann", frame_length=16),
318
+ frame_length=16,
319
+ hop_length=4,
320
+ power=1.0,
321
+ center=True,
322
+ pad_mode="reflect",
323
+ onesided=True,
324
+ )
325
+
326
+ self.assertEqual(spec_list[0].shape, (9, 11))
327
+ self.assertEqual(spec_list[1].shape, (9, 8))
328
+ self.assertEqual(spec_list[2].shape, (9, 13))
329
+
330
+ expected1 = np.array([[0.0, 0.0669873, 0.9330127, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]])
331
+ expected2 = np.array([[0.0, 0.0, 0.75, 3.0, 0.75, 0.0, 0.0, 0.0]])
332
+ expected3 = np.array([[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.375, 3.375, 0.0, 0.0, 0.0, 0.0, 0.0]])
333
+
334
+ self.assertTrue(np.allclose(spec_list[0], expected1))
335
+ self.assertTrue(np.allclose(spec_list[1], expected2))
336
+ self.assertTrue(np.allclose(spec_list[2], expected3))
337
+
338
+ def test_spectrogram_integration_test(self):
339
+ waveform = self._load_datasamples(1)[0]
340
+
341
+ spec = spectrogram(
342
+ waveform,
343
+ window_function(400, "hann", frame_length=512),
344
+ frame_length=512,
345
+ hop_length=128,
346
+ power=1.0,
347
+ center=True,
348
+ pad_mode="reflect",
349
+ onesided=True,
350
+ )
351
+ self.assertEqual(spec.shape, (257, 732))
352
+
353
+ # fmt: off
354
+ expected = np.array([
355
+ 0.02464888, 0.04648664, 0.05872392, 0.02311783, 0.0327175 ,
356
+ 0.02433643, 0.01198814, 0.02055709, 0.01559287, 0.01394357,
357
+ 0.01299037, 0.01728045, 0.0254554 , 0.02486533, 0.02011792,
358
+ 0.01755333, 0.02100457, 0.02337024, 0.01436963, 0.01464558,
359
+ 0.0211017 , 0.0193489 , 0.01272165, 0.01858462, 0.03722598,
360
+ 0.0456542 , 0.03281558, 0.00620586, 0.02226466, 0.03618042,
361
+ 0.03508182, 0.02271432, 0.01051649, 0.01225771, 0.02315293,
362
+ 0.02331886, 0.01417785, 0.0106844 , 0.01791214, 0.017177 ,
363
+ 0.02125114, 0.05028201, 0.06830665, 0.05216664, 0.01963666,
364
+ 0.06941418, 0.11513043, 0.12257859, 0.10948435, 0.08568069,
365
+ 0.05509328, 0.05047818, 0.047112 , 0.05060737, 0.02982424,
366
+ 0.02803827, 0.02933729, 0.01760491, 0.00587815, 0.02117637,
367
+ 0.0293578 , 0.03452379, 0.02194803, 0.01676056,
368
+ ])
369
+ # fmt: on
370
+ self.assertTrue(np.allclose(spec[:64, 400], expected))
371
+
372
+ spec = spectrogram(
373
+ waveform,
374
+ window_function(400, "hann"),
375
+ frame_length=400,
376
+ hop_length=128,
377
+ fft_length=512,
378
+ power=1.0,
379
+ center=True,
380
+ pad_mode="reflect",
381
+ onesided=True,
382
+ )
383
+ self.assertEqual(spec.shape, (257, 732))
384
+ self.assertTrue(np.allclose(spec[:64, 400], expected))
385
+
386
+ mel_filters = mel_filter_bank(
387
+ num_frequency_bins=257,
388
+ num_mel_filters=400,
389
+ min_frequency=20,
390
+ max_frequency=8000,
391
+ sampling_rate=16000,
392
+ norm=None,
393
+ mel_scale="kaldi",
394
+ triangularize_in_mel_space=True,
395
+ )
396
+
397
+ spec = spectrogram(
398
+ waveform,
399
+ window_function(400, "povey", periodic=False),
400
+ frame_length=400,
401
+ hop_length=160,
402
+ fft_length=512,
403
+ power=2.0,
404
+ center=False,
405
+ pad_mode="reflect",
406
+ onesided=True,
407
+ preemphasis=0.97,
408
+ mel_filters=mel_filters,
409
+ log_mel="log",
410
+ mel_floor=1.1920928955078125e-07,
411
+ remove_dc_offset=True,
412
+ )
413
+ self.assertEqual(spec.shape, (400, 584))
414
+
415
+ # fmt: off
416
+ expected = np.array([-15.94238515, -8.20712299, -8.22704352, -15.94238515,
417
+ -15.94238515, -15.94238515, -15.94238515, -15.94238515,
418
+ -6.52463769, -7.73677889, -15.94238515, -15.94238515,
419
+ -15.94238515, -15.94238515, -4.18650018, -3.37195286,
420
+ -15.94238515, -15.94238515, -15.94238515, -15.94238515,
421
+ -4.70190154, -2.4217066 , -15.94238515, -15.94238515,
422
+ -15.94238515, -15.94238515, -5.62755239, -3.53385194,
423
+ -15.94238515, -15.94238515, -15.94238515, -15.94238515,
424
+ -9.43303023, -8.77480925, -15.94238515, -15.94238515,
425
+ -15.94238515, -15.94238515, -4.2951092 , -5.51585994,
426
+ -15.94238515, -15.94238515, -15.94238515, -4.40151721,
427
+ -3.95228878, -15.94238515, -15.94238515, -15.94238515,
428
+ -6.10365415, -4.59494697, -15.94238515, -15.94238515,
429
+ -15.94238515, -8.10727767, -6.2585298 , -15.94238515,
430
+ -15.94238515, -15.94238515, -5.60161702, -4.47217004,
431
+ -15.94238515, -15.94238515, -15.94238515, -5.91641988]
432
+ )
433
+ # fmt: on
434
+ self.assertTrue(np.allclose(spec[:64, 400], expected, atol=1e-5))
435
+
436
+ def test_spectrogram_batch_integration_test(self):
437
+ waveform_list = self._load_datasamples(3)
438
+
439
+ spec_list = spectrogram_batch(
440
+ waveform_list,
441
+ window_function(400, "hann", frame_length=512),
442
+ frame_length=512,
443
+ hop_length=128,
444
+ power=1.0,
445
+ center=True,
446
+ pad_mode="reflect",
447
+ onesided=True,
448
+ )
449
+ self.assertEqual(spec_list[0].shape, (257, 732))
450
+ self.assertEqual(spec_list[1].shape, (257, 602))
451
+ self.assertEqual(spec_list[2].shape, (257, 1561))
452
+
453
+ # fmt: off
454
+ expected1 = np.array([
455
+ 0.02464888, 0.04648664, 0.05872392, 0.02311783, 0.0327175 ,
456
+ 0.02433643, 0.01198814, 0.02055709, 0.01559287, 0.01394357,
457
+ 0.01299037, 0.01728045, 0.0254554 , 0.02486533, 0.02011792,
458
+ 0.01755333, 0.02100457, 0.02337024, 0.01436963, 0.01464558,
459
+ 0.0211017 , 0.0193489 , 0.01272165, 0.01858462, 0.03722598,
460
+ 0.0456542 , 0.03281558, 0.00620586, 0.02226466, 0.03618042,
461
+ 0.03508182, 0.02271432, 0.01051649, 0.01225771, 0.02315293,
462
+ 0.02331886, 0.01417785, 0.0106844 , 0.01791214, 0.017177 ,
463
+ 0.02125114, 0.05028201, 0.06830665, 0.05216664, 0.01963666,
464
+ 0.06941418, 0.11513043, 0.12257859, 0.10948435, 0.08568069,
465
+ 0.05509328, 0.05047818, 0.047112 , 0.05060737, 0.02982424,
466
+ 0.02803827, 0.02933729, 0.01760491, 0.00587815, 0.02117637,
467
+ 0.0293578 , 0.03452379, 0.02194803, 0.01676056,
468
+ ])
469
+ expected2 = np.array([
470
+ 7.61983171e-02, 1.45338190e-01, 2.63903728e+00, 7.74429535e+00,
471
+ 9.61932980e+00, 5.40767686e+00, 1.08924884e+00, 3.40908262e+00,
472
+ 3.59484250e+00, 1.68451077e+00, 5.88405873e-01, 1.17042530e+00,
473
+ 9.94803324e-01, 3.53757065e-01, 5.47699239e-01, 9.48368581e-01,
474
+ 7.17770457e-01, 2.09396633e-01, 1.77574463e-01, 2.35644731e-01,
475
+ 1.31535991e-01, 1.53539552e-02, 4.34416305e-02, 5.32897267e-02,
476
+ 4.03567305e-02, 1.41842226e-02, 2.90514538e-02, 3.36549485e-02,
477
+ 1.53516624e-02, 2.37464225e-02, 4.60092464e-02, 4.05769324e-02,
478
+ 4.82633401e-03, 4.12675364e-02, 7.13859796e-02, 6.16866566e-02,
479
+ 2.55657822e-02, 1.68923281e-02, 1.91299946e-02, 1.60033798e-02,
480
+ 1.33405095e-02, 1.52065457e-02, 1.21833352e-02, 2.25786382e-03,
481
+ 6.15358376e-03, 1.07647616e-02, 1.23051018e-02, 6.75289378e-03,
482
+ 2.71127435e-03, 1.06515263e-02, 1.18463583e-02, 7.14347935e-03,
483
+ 1.87912782e-03, 4.44236027e-03, 5.19630243e-03, 2.46666998e-03,
484
+ 1.01598645e-03, 1.21589237e-03, 1.29095500e-03, 1.07447628e-03,
485
+ 1.40218156e-03, 3.65402623e-03, 4.00592755e-03, 4.20001841e-03
486
+ ])
487
+ expected3 = np.array([
488
+ 0.07805249, 0.34305022, 0.55617084, 1.22475182, 1.17040678,
489
+ 0.51540532, 0.23570016, 0.06630775, 0.09017777, 0.07693192,
490
+ 0.0333643 , 0.04873054, 0.04668559, 0.02384041, 0.02780435,
491
+ 0.0289717 , 0.01704903, 0.0201644 , 0.01700376, 0.02176975,
492
+ 0.02042491, 0.00732129, 0.00326042, 0.00245065, 0.00510645,
493
+ 0.00681892, 0.00739329, 0.00551437, 0.0070674 , 0.00630015,
494
+ 0.00379566, 0.0060098 , 0.00311543, 0.00902284, 0.01171038,
495
+ 0.01202166, 0.01759194, 0.01652899, 0.01201872, 0.01295351,
496
+ 0.00756432, 0.01415318, 0.02349972, 0.02296833, 0.02429341,
497
+ 0.02447459, 0.01835044, 0.01437871, 0.02262246, 0.02972324,
498
+ 0.03392252, 0.03037546, 0.01116927, 0.01555062, 0.02833379,
499
+ 0.02294212, 0.02069847, 0.02496927, 0.02273526, 0.01341643,
500
+ 0.00805407, 0.00624943, 0.01076262, 0.01876003
501
+ ])
502
+ # fmt: on
503
+ self.assertTrue(np.allclose(spec_list[0][:64, 400], expected1))
504
+ self.assertTrue(np.allclose(spec_list[1][:64, 400], expected2))
505
+ self.assertTrue(np.allclose(spec_list[2][:64, 400], expected3))
506
+
507
+ spec_list = spectrogram_batch(
508
+ waveform_list,
509
+ window_function(400, "hann"),
510
+ frame_length=400,
511
+ hop_length=128,
512
+ fft_length=512,
513
+ power=1.0,
514
+ center=True,
515
+ pad_mode="reflect",
516
+ onesided=True,
517
+ )
518
+ self.assertEqual(spec_list[0].shape, (257, 732))
519
+ self.assertEqual(spec_list[1].shape, (257, 602))
520
+ self.assertEqual(spec_list[2].shape, (257, 1561))
521
+ self.assertTrue(np.allclose(spec_list[0][:64, 400], expected1))
522
+ self.assertTrue(np.allclose(spec_list[1][:64, 400], expected2))
523
+ self.assertTrue(np.allclose(spec_list[2][:64, 400], expected3))
524
+
525
+ mel_filters = mel_filter_bank(
526
+ num_frequency_bins=257,
527
+ num_mel_filters=400,
528
+ min_frequency=20,
529
+ max_frequency=8000,
530
+ sampling_rate=16000,
531
+ norm=None,
532
+ mel_scale="kaldi",
533
+ triangularize_in_mel_space=True,
534
+ )
535
+
536
+ spec_list = spectrogram_batch(
537
+ waveform_list,
538
+ window_function(400, "povey", periodic=False),
539
+ frame_length=400,
540
+ hop_length=160,
541
+ fft_length=512,
542
+ power=2.0,
543
+ center=False,
544
+ pad_mode="reflect",
545
+ onesided=True,
546
+ preemphasis=0.97,
547
+ mel_filters=mel_filters,
548
+ log_mel="log",
549
+ mel_floor=1.1920928955078125e-07,
550
+ remove_dc_offset=True,
551
+ )
552
+ self.assertEqual(spec_list[0].shape, (400, 584))
553
+ self.assertEqual(spec_list[1].shape, (400, 480))
554
+ self.assertEqual(spec_list[2].shape, (400, 1247))
555
+
556
+ # fmt: off
557
+ expected1 = np.array([-15.94238515, -8.20712299, -8.22704352, -15.94238515,
558
+ -15.94238515, -15.94238515, -15.94238515, -15.94238515,
559
+ -6.52463769, -7.73677889, -15.94238515, -15.94238515,
560
+ -15.94238515, -15.94238515, -4.18650018, -3.37195286,
561
+ -15.94238515, -15.94238515, -15.94238515, -15.94238515,
562
+ -4.70190154, -2.4217066 , -15.94238515, -15.94238515,
563
+ -15.94238515, -15.94238515, -5.62755239, -3.53385194,
564
+ -15.94238515, -15.94238515, -15.94238515, -15.94238515,
565
+ -9.43303023, -8.77480925, -15.94238515, -15.94238515,
566
+ -15.94238515, -15.94238515, -4.2951092 , -5.51585994,
567
+ -15.94238515, -15.94238515, -15.94238515, -4.40151721,
568
+ -3.95228878, -15.94238515, -15.94238515, -15.94238515,
569
+ -6.10365415, -4.59494697, -15.94238515, -15.94238515,
570
+ -15.94238515, -8.10727767, -6.2585298 , -15.94238515,
571
+ -15.94238515, -15.94238515, -5.60161702, -4.47217004,
572
+ -15.94238515, -15.94238515, -15.94238515, -5.91641988]
573
+ )
574
+ expected2 = np.array([-15.942385, -8.531508, -8.551396, -15.942385, -15.942385,
575
+ -15.942385, -15.942385, -15.942385, -5.626043, -6.8381968,
576
+ -15.942385, -15.942385, -15.942385, -15.942385, -3.3122184,
577
+ -2.49764, -15.942385, -15.942385, -15.942385, -15.942385,
578
+ -3.625868, -1.3457257, -15.942385, -15.942385, -15.942385,
579
+ -15.942385, -4.2223063, -2.1285915, -15.942385, -15.942385,
580
+ -15.942385, -15.942385, -8.611152, -7.952894, -15.942385,
581
+ -15.942385, -15.942385, -15.942385, -2.7585578, -3.9793255,
582
+ -15.942385, -15.942385, -15.942385, -2.5377562, -2.0885658,
583
+ -15.942385, -15.942385, -15.942385, -3.8310733, -2.322393,
584
+ -15.942385, -15.942385, -15.942385, -7.674944, -5.8261633,
585
+ -15.942385, -15.942385, -15.942385, -3.5960004, -2.4665844,
586
+ -15.942385, -15.942385, -15.942385, -1.7905309]
587
+ )
588
+ expected3 = np.array([-15.942385, -13.406995, -13.426883, -15.942385, -15.942385,
589
+ -15.942385, -15.942385, -15.942385, -15.942385, -15.942385,
590
+ -15.942385, -15.942385, -15.942385, -15.942385, -13.493383,
591
+ -12.678805, -15.942385, -15.942385, -15.942385, -15.942385,
592
+ -14.809377, -12.529235, -15.942385, -15.942385, -15.942385,
593
+ -15.942385, -13.838827, -11.745112, -15.942385, -15.942385,
594
+ -15.942385, -15.942385, -13.9336405, -13.275384, -15.942385,
595
+ -15.942385, -15.942385, -15.942385, -13.043786, -14.264554,
596
+ -15.942385, -15.942385, -15.942385, -13.060181, -12.610991,
597
+ -15.942385, -15.942385, -15.942385, -14.152064, -12.643384,
598
+ -15.942385, -15.942385, -15.942385, -14.48317, -12.634389,
599
+ -15.942385, -15.942385, -15.942385, -14.627316, -13.4979,
600
+ -15.942385, -15.942385, -15.942385, -12.6279955]
601
+ )
602
+ # fmt: on
603
+ self.assertTrue(np.allclose(spec_list[0][:64, 400], expected1, atol=1e-5))
604
+ self.assertTrue(np.allclose(spec_list[1][:64, 400], expected2, atol=1e-5))
605
+ self.assertTrue(np.allclose(spec_list[2][:64, 400], expected3, atol=1e-5))
606
+
607
+ def test_spectrogram_center_padding(self):
608
+ waveform = self._load_datasamples(1)[0]
609
+
610
+ spec = spectrogram(
611
+ waveform,
612
+ window_function(512, "hann"),
613
+ frame_length=512,
614
+ hop_length=128,
615
+ center=True,
616
+ pad_mode="reflect",
617
+ )
618
+ self.assertEqual(spec.shape, (257, 732))
619
+
620
+ # fmt: off
621
+ expected = np.array([
622
+ 0.1287945 , 0.12792738, 0.08311573, 0.03155122, 0.02470202,
623
+ 0.00727857, 0.00910694, 0.00686163, 0.01238981, 0.01473668,
624
+ 0.00336144, 0.00370314, 0.00600871, 0.01120164, 0.01942998,
625
+ 0.03132008, 0.0232842 , 0.01124642, 0.02754783, 0.02423725,
626
+ 0.00147893, 0.00038027, 0.00112299, 0.00596233, 0.00571529,
627
+ 0.02084235, 0.0231855 , 0.00810006, 0.01837943, 0.00651339,
628
+ 0.00093931, 0.00067426, 0.01058399, 0.01270507, 0.00151734,
629
+ 0.00331913, 0.00302416, 0.01081792, 0.00754549, 0.00148963,
630
+ 0.00111943, 0.00152573, 0.00608017, 0.01749986, 0.01205949,
631
+ 0.0143082 , 0.01910573, 0.00413786, 0.03916619, 0.09873404,
632
+ 0.08302026, 0.02673891, 0.00401255, 0.01397392, 0.00751862,
633
+ 0.01024884, 0.01544606, 0.00638907, 0.00623633, 0.0085103 ,
634
+ 0.00217659, 0.00276204, 0.00260835, 0.00299299,
635
+ ])
636
+ # fmt: on
637
+ self.assertTrue(np.allclose(spec[:64, 0], expected))
638
+
639
+ spec = spectrogram(
640
+ waveform,
641
+ window_function(512, "hann"),
642
+ frame_length=512,
643
+ hop_length=128,
644
+ center=True,
645
+ pad_mode="constant",
646
+ )
647
+ self.assertEqual(spec.shape, (257, 732))
648
+
649
+ # fmt: off
650
+ expected = np.array([
651
+ 0.06558744, 0.06889656, 0.06263352, 0.04264418, 0.03404115,
652
+ 0.03244197, 0.02279134, 0.01646339, 0.01452216, 0.00826055,
653
+ 0.00062093, 0.0031821 , 0.00419456, 0.00689327, 0.01106367,
654
+ 0.01712119, 0.01721762, 0.00977533, 0.01606626, 0.02275621,
655
+ 0.01727687, 0.00992739, 0.01217688, 0.01049927, 0.01022947,
656
+ 0.01302475, 0.01166873, 0.01081812, 0.01057327, 0.00767912,
657
+ 0.00429567, 0.00089625, 0.00654583, 0.00912084, 0.00700984,
658
+ 0.00225026, 0.00290545, 0.00667712, 0.00730663, 0.00410813,
659
+ 0.00073102, 0.00219296, 0.00527618, 0.00996585, 0.01123781,
660
+ 0.00872816, 0.01165121, 0.02047945, 0.03681747, 0.0514379 ,
661
+ 0.05137928, 0.03960042, 0.02821562, 0.01813349, 0.01201322,
662
+ 0.01260964, 0.00900654, 0.00207905, 0.00456714, 0.00850599,
663
+ 0.00788239, 0.00664407, 0.00824227, 0.00628301,
664
+ ])
665
+ # fmt: on
666
+ self.assertTrue(np.allclose(spec[:64, 0], expected))
667
+
668
+ spec = spectrogram(
669
+ waveform,
670
+ window_function(512, "hann"),
671
+ frame_length=512,
672
+ hop_length=128,
673
+ center=False,
674
+ )
675
+ self.assertEqual(spec.shape, (257, 728))
676
+
677
+ # fmt: off
678
+ expected = np.array([
679
+ 0.00250445, 0.02161521, 0.06232229, 0.04339567, 0.00937727,
680
+ 0.01080616, 0.00248685, 0.0095264 , 0.00727476, 0.0079152 ,
681
+ 0.00839946, 0.00254932, 0.00716622, 0.005559 , 0.00272623,
682
+ 0.00581774, 0.01896395, 0.01829788, 0.01020514, 0.01632692,
683
+ 0.00870888, 0.02065827, 0.0136022 , 0.0132382 , 0.011827 ,
684
+ 0.00194505, 0.0189979 , 0.026874 , 0.02194014, 0.01923883,
685
+ 0.01621437, 0.00661967, 0.00289517, 0.00470257, 0.00957801,
686
+ 0.00191455, 0.00431664, 0.00544359, 0.01126213, 0.00785778,
687
+ 0.00423469, 0.01322504, 0.02226548, 0.02318576, 0.03428908,
688
+ 0.03648811, 0.0202938 , 0.011902 , 0.03226198, 0.06347476,
689
+ 0.01306318, 0.05308729, 0.05474771, 0.03127991, 0.00998512,
690
+ 0.01449977, 0.01272741, 0.00868176, 0.00850386, 0.00313876,
691
+ 0.00811857, 0.00538216, 0.00685749, 0.00535275,
692
+ ])
693
+ # fmt: on
694
+ self.assertTrue(np.allclose(spec[:64, 0], expected))
695
+
696
+ def test_spectrogram_batch_center_padding(self):
697
+ waveform_list = self._load_datasamples(3)
698
+
699
+ spec_list = spectrogram_batch(
700
+ waveform_list,
701
+ window_function(512, "hann"),
702
+ frame_length=512,
703
+ hop_length=128,
704
+ center=True,
705
+ pad_mode="reflect",
706
+ )
707
+ self.assertEqual(spec_list[0].shape, (257, 732))
708
+ self.assertEqual(spec_list[1].shape, (257, 602))
709
+ self.assertEqual(spec_list[2].shape, (257, 1561))
710
+
711
+ # fmt: off
712
+ expected1 = np.array([
713
+ 0.1287945 , 0.12792738, 0.08311573, 0.03155122, 0.02470202,
714
+ 0.00727857, 0.00910694, 0.00686163, 0.01238981, 0.01473668,
715
+ 0.00336144, 0.00370314, 0.00600871, 0.01120164, 0.01942998,
716
+ 0.03132008, 0.0232842 , 0.01124642, 0.02754783, 0.02423725,
717
+ 0.00147893, 0.00038027, 0.00112299, 0.00596233, 0.00571529,
718
+ 0.02084235, 0.0231855 , 0.00810006, 0.01837943, 0.00651339,
719
+ 0.00093931, 0.00067426, 0.01058399, 0.01270507, 0.00151734,
720
+ 0.00331913, 0.00302416, 0.01081792, 0.00754549, 0.00148963,
721
+ 0.00111943, 0.00152573, 0.00608017, 0.01749986, 0.01205949,
722
+ 0.0143082 , 0.01910573, 0.00413786, 0.03916619, 0.09873404,
723
+ 0.08302026, 0.02673891, 0.00401255, 0.01397392, 0.00751862,
724
+ 0.01024884, 0.01544606, 0.00638907, 0.00623633, 0.0085103 ,
725
+ 0.00217659, 0.00276204, 0.00260835, 0.00299299,
726
+ ])
727
+ expected2 = np.array([
728
+ 1.89624839e-02, 1.23274978e-02, 3.69160250e-02, 4.76267971e-02,
729
+ 1.39258439e-02, 2.98370440e-02, 2.74845166e-03, 3.01934010e-03,
730
+ 1.18722776e-02, 9.70834121e-03, 2.06300567e-04, 6.32975250e-04,
731
+ 8.20603687e-03, 1.21864351e-02, 3.28791840e-03, 3.36801982e-04,
732
+ 2.79373326e-03, 5.00530424e-03, 8.46884679e-03, 1.14089288e-02,
733
+ 8.59052036e-03, 2.88538425e-03, 9.95071139e-03, 6.80431770e-03,
734
+ 2.95809377e-03, 1.46285209e-04, 3.36268265e-03, 4.80051298e-04,
735
+ 2.84506916e-03, 9.34222655e-04, 3.42161348e-03, 2.79612141e-03,
736
+ 3.38875921e-03, 2.85030343e-03, 5.39513239e-05, 2.72908504e-03,
737
+ 2.09591188e-03, 5.00271388e-04, 8.31917219e-04, 2.37967237e-03,
738
+ 1.75001193e-03, 1.31826295e-04, 8.83622793e-04, 1.54303256e-04,
739
+ 3.09544569e-03, 4.08527814e-03, 2.73566321e-03, 1.78805250e-03,
740
+ 9.53314066e-06, 1.74316950e-03, 1.51099428e-03, 8.65990878e-04,
741
+ 8.44859460e-04, 5.35220199e-04, 5.36562002e-04, 8.33181897e-04,
742
+ 8.22705682e-04, 1.81083288e-03, 9.75003233e-04, 6.73114730e-04,
743
+ 6.81665202e-04, 2.05180887e-03, 1.10151991e-03, 4.75923851e-04,
744
+ ])
745
+ expected3 = np.array([
746
+ 0.07079848, 0.04237922, 0.0220724, 0.04446052, 0.03598337,
747
+ 0.03327273, 0.02545774, 0.01319528, 0.00919659, 0.01376867,
748
+ 0.00361992, 0.00608425, 0.01105873, 0.0105565, 0.00744286,
749
+ 0.00244849, 0.00257317, 0.00749989, 0.01061386, 0.01525312,
750
+ 0.00656914, 0.01199581, 0.00487319, 0.00830956, 0.0046706,
751
+ 0.00588962, 0.00544486, 0.00565179, 0.00050112, 0.01108059,
752
+ 0.00217417, 0.00453234, 0.00537306, 0.00269329, 0.00342333,
753
+ 0.00095484, 0.00708934, 0.00660373, 0.00543686, 0.00217186,
754
+ 0.00431519, 0.00457764, 0.00503529, 0.01166454, 0.01375581,
755
+ 0.01467224, 0.00873404, 0.00534086, 0.00476848, 0.0226163,
756
+ 0.0314, 0.00151021, 0.01975221, 0.01637519, 0.00046068,
757
+ 0.0460544, 0.06285986, 0.03151625, 0.0013598, 0.004804,
758
+ 0.0073824, 0.02312599, 0.02613977, 0.01056851
759
+ ])
760
+ # fmt: on
761
+ self.assertTrue(np.allclose(spec_list[0][:64, 0], expected1))
762
+ self.assertTrue(np.allclose(spec_list[1][:64, 0], expected2))
763
+ self.assertTrue(np.allclose(spec_list[2][:64, 0], expected3))
764
+
765
+ spec_list = spectrogram_batch(
766
+ waveform_list,
767
+ window_function(512, "hann"),
768
+ frame_length=512,
769
+ hop_length=128,
770
+ center=True,
771
+ pad_mode="constant",
772
+ )
773
+ self.assertEqual(spec_list[0].shape, (257, 732))
774
+ self.assertEqual(spec_list[1].shape, (257, 602))
775
+ self.assertEqual(spec_list[2].shape, (257, 1561))
776
+
777
+ # fmt: off
778
+ expected1 = np.array([
779
+ 0.06558744, 0.06889656, 0.06263352, 0.04264418, 0.03404115,
780
+ 0.03244197, 0.02279134, 0.01646339, 0.01452216, 0.00826055,
781
+ 0.00062093, 0.0031821 , 0.00419456, 0.00689327, 0.01106367,
782
+ 0.01712119, 0.01721762, 0.00977533, 0.01606626, 0.02275621,
783
+ 0.01727687, 0.00992739, 0.01217688, 0.01049927, 0.01022947,
784
+ 0.01302475, 0.01166873, 0.01081812, 0.01057327, 0.00767912,
785
+ 0.00429567, 0.00089625, 0.00654583, 0.00912084, 0.00700984,
786
+ 0.00225026, 0.00290545, 0.00667712, 0.00730663, 0.00410813,
787
+ 0.00073102, 0.00219296, 0.00527618, 0.00996585, 0.01123781,
788
+ 0.00872816, 0.01165121, 0.02047945, 0.03681747, 0.0514379 ,
789
+ 0.05137928, 0.03960042, 0.02821562, 0.01813349, 0.01201322,
790
+ 0.01260964, 0.00900654, 0.00207905, 0.00456714, 0.00850599,
791
+ 0.00788239, 0.00664407, 0.00824227, 0.00628301,
792
+ ])
793
+ expected2 = np.array([
794
+ 0.00955754, 0.01445548, 0.02393902, 0.02903068, 0.02512844,
795
+ 0.01508297, 0.00474784, 0.00440362, 0.0073898, 0.00546519,
796
+ 0.00126077, 0.00240507, 0.00523254, 0.00632742, 0.00415215,
797
+ 0.00056628, 0.00161288, 0.0026956, 0.00431587, 0.00621471,
798
+ 0.00791291, 0.0079454, 0.00594525, 0.00334581, 0.00180047,
799
+ 0.00144485, 0.00175764, 0.00188037, 0.00134889, 0.00150253,
800
+ 0.00178821, 0.00158875, 0.00204339, 0.00266497, 0.00280556,
801
+ 0.00221949, 0.00108956, 0.000532, 0.00108454, 0.00129254,
802
+ 0.00089315, 0.00022803, 0.00038176, 0.0011302, 0.00189306,
803
+ 0.0021964, 0.00203576, 0.00207306, 0.00217727, 0.00174297,
804
+ 0.00103331, 0.00076695, 0.0007422, 0.00061986, 0.00081204,
805
+ 0.00079615, 0.00089417, 0.00105452, 0.00042615, 0.00066372,
806
+ 0.00132765, 0.00122087, 0.00054903, 0.00107945,
807
+ ])
808
+ expected3 = np.array([
809
+ 0.03573493, 0.03625983, 0.03341755, 0.02431477, 0.01770546,
810
+ 0.0169356 , 0.01579034, 0.01600499, 0.01329064, 0.00747957,
811
+ 0.00367372, 0.00403853, 0.00519597, 0.00551022, 0.00532757,
812
+ 0.00367569, 0.00130341, 0.00345149, 0.00520744, 0.00872308,
813
+ 0.01172503, 0.00948154, 0.00344236, 0.00387997, 0.00425455,
814
+ 0.00394357, 0.00711733, 0.00615654, 0.00055756, 0.00656414,
815
+ 0.00852001, 0.00666252, 0.00509767, 0.00246784, 0.00376049,
816
+ 0.00682879, 0.00641118, 0.00469685, 0.00358701, 0.0015552 ,
817
+ 0.00261458, 0.00701979, 0.00929578, 0.00894536, 0.00828491,
818
+ 0.00773528, 0.00552091, 0.00259871, 0.00933179, 0.01588626,
819
+ 0.01697887, 0.01268552, 0.00957255, 0.01204092, 0.02123362,
820
+ 0.03062669, 0.03215763, 0.02629963, 0.01769568, 0.01088869,
821
+ 0.01151334, 0.01378197, 0.01319263, 0.01066859,
822
+ ])
823
+ # fmt: on
824
+ self.assertTrue(np.allclose(spec_list[0][:64, 0], expected1))
825
+ self.assertTrue(np.allclose(spec_list[1][:64, 0], expected2))
826
+ self.assertTrue(np.allclose(spec_list[2][:64, 0], expected3))
827
+
828
+ spec_list = spectrogram_batch(
829
+ waveform_list,
830
+ window_function(512, "hann"),
831
+ frame_length=512,
832
+ hop_length=128,
833
+ center=False,
834
+ )
835
+ self.assertEqual(spec_list[0].shape, (257, 728))
836
+ self.assertEqual(spec_list[1].shape, (257, 598))
837
+ self.assertEqual(spec_list[2].shape, (257, 1557))
838
+
839
+ # fmt: off
840
+ expected1 = np.array([
841
+ 0.00250445, 0.02161521, 0.06232229, 0.04339567, 0.00937727,
842
+ 0.01080616, 0.00248685, 0.0095264 , 0.00727476, 0.0079152 ,
843
+ 0.00839946, 0.00254932, 0.00716622, 0.005559 , 0.00272623,
844
+ 0.00581774, 0.01896395, 0.01829788, 0.01020514, 0.01632692,
845
+ 0.00870888, 0.02065827, 0.0136022 , 0.0132382 , 0.011827 ,
846
+ 0.00194505, 0.0189979 , 0.026874 , 0.02194014, 0.01923883,
847
+ 0.01621437, 0.00661967, 0.00289517, 0.00470257, 0.00957801,
848
+ 0.00191455, 0.00431664, 0.00544359, 0.01126213, 0.00785778,
849
+ 0.00423469, 0.01322504, 0.02226548, 0.02318576, 0.03428908,
850
+ 0.03648811, 0.0202938 , 0.011902 , 0.03226198, 0.06347476,
851
+ 0.01306318, 0.05308729, 0.05474771, 0.03127991, 0.00998512,
852
+ 0.01449977, 0.01272741, 0.00868176, 0.00850386, 0.00313876,
853
+ 0.00811857, 0.00538216, 0.00685749, 0.00535275,
854
+ ])
855
+ expected2 = np.array([
856
+ 0.01232908, 0.05980514, 0.08285419, 0.01850723, 0.02823627,
857
+ 0.00204369, 0.01372626, 0.00956435, 0.02267217, 0.00947112,
858
+ 0.00355174, 0.00418008, 0.00843608, 0.01559252, 0.01125505,
859
+ 0.00183573, 0.00765051, 0.0109983 , 0.00890545, 0.00583453,
860
+ 0.00115901, 0.00579039, 0.00151353, 0.00395812, 0.00231413,
861
+ 0.00384272, 0.00313914, 0.00072331, 0.00338935, 0.00383328,
862
+ 0.00218129, 0.00284516, 0.00228538, 0.00083603, 0.00111663,
863
+ 0.00235799, 0.00142748, 0.00092908, 0.0012966 , 0.0011403 ,
864
+ 0.0010619 , 0.00158732, 0.00289866, 0.00216709, 0.00313325,
865
+ 0.00361277, 0.00202507, 0.0009948 , 0.00114428, 0.00200851,
866
+ 0.0009234 , 0.00063468, 0.00018746, 0.00100463, 0.00053799,
867
+ 0.00080009, 0.00158291, 0.00172077, 0.00173586, 0.00197127,
868
+ 0.00107058, 0.00043486, 0.0009859 , 0.00215484,
869
+ ])
870
+ expected3 = np.array([
871
+ 0.01864123, 0.06131337, 0.08346292, 0.04936386, 0.02792609,
872
+ 0.01005205, 0.00884826, 0.02198604, 0.02421535, 0.00957573,
873
+ 0.00503561, 0.00241331, 0.00175652, 0.00195889, 0.00453299,
874
+ 0.0020317 , 0.00249264, 0.00517483, 0.01111943, 0.0150079 ,
875
+ 0.01977743, 0.01253825, 0.00517561, 0.01031712, 0.00579466,
876
+ 0.00783679, 0.0071415 , 0.00591847, 0.01510728, 0.01194921,
877
+ 0.00518072, 0.00125978, 0.00577552, 0.01050614, 0.0077644 ,
878
+ 0.0042905 , 0.00278469, 0.00166695, 0.00255013, 0.00578153,
879
+ 0.00586451, 0.00929514, 0.01501226, 0.00741419, 0.00310625,
880
+ 0.00086757, 0.00595618, 0.0053882 , 0.0116266 , 0.02504773,
881
+ 0.02889692, 0.03739442, 0.04730207, 0.03856638, 0.05700104,
882
+ 0.04299267, 0.02153366, 0.03740607, 0.03811468, 0.01575022,
883
+ 0.00676344, 0.01359865, 0.01769319, 0.00907966,
884
+ ])
885
+ # fmt: on
886
+ self.assertTrue(np.allclose(spec_list[0][:64, 0], expected1))
887
+ self.assertTrue(np.allclose(spec_list[1][:64, 0], expected2))
888
+ self.assertTrue(np.allclose(spec_list[2][:64, 0], expected3))
889
+
890
+ def test_spectrogram_shapes(self):
891
+ waveform = self._load_datasamples(1)[0]
892
+
893
+ spec = spectrogram(
894
+ waveform,
895
+ window_function(400, "hann"),
896
+ frame_length=400,
897
+ hop_length=128,
898
+ power=1.0,
899
+ center=True,
900
+ pad_mode="reflect",
901
+ onesided=True,
902
+ )
903
+ self.assertEqual(spec.shape, (201, 732))
904
+
905
+ spec = spectrogram(
906
+ waveform,
907
+ window_function(400, "hann"),
908
+ frame_length=400,
909
+ hop_length=128,
910
+ power=1.0,
911
+ center=False,
912
+ pad_mode="reflect",
913
+ onesided=True,
914
+ )
915
+ self.assertEqual(spec.shape, (201, 729))
916
+
917
+ spec = spectrogram(
918
+ waveform,
919
+ window_function(400, "hann"),
920
+ frame_length=400,
921
+ hop_length=128,
922
+ fft_length=512,
923
+ power=1.0,
924
+ center=True,
925
+ pad_mode="reflect",
926
+ onesided=True,
927
+ )
928
+ self.assertEqual(spec.shape, (257, 732))
929
+
930
+ spec = spectrogram(
931
+ waveform,
932
+ window_function(400, "hann", frame_length=512),
933
+ frame_length=512,
934
+ hop_length=64,
935
+ power=1.0,
936
+ center=True,
937
+ pad_mode="reflect",
938
+ onesided=False,
939
+ )
940
+ self.assertEqual(spec.shape, (512, 1464))
941
+
942
+ spec = spectrogram(
943
+ waveform,
944
+ window_function(512, "hann"),
945
+ frame_length=512,
946
+ hop_length=64,
947
+ power=1.0,
948
+ center=True,
949
+ pad_mode="reflect",
950
+ onesided=False,
951
+ )
952
+ self.assertEqual(spec.shape, (512, 1464))
953
+
954
+ spec = spectrogram(
955
+ waveform,
956
+ window_function(512, "hann"),
957
+ frame_length=512,
958
+ hop_length=512,
959
+ power=1.0,
960
+ center=True,
961
+ pad_mode="reflect",
962
+ onesided=False,
963
+ )
964
+ self.assertEqual(spec.shape, (512, 183))
965
+
966
+ def test_spectrogram_batch_shapes(self):
967
+ waveform_list = self._load_datasamples(3)
968
+
969
+ spec_list = spectrogram_batch(
970
+ waveform_list,
971
+ window_function(400, "hann"),
972
+ frame_length=400,
973
+ hop_length=128,
974
+ power=1.0,
975
+ center=True,
976
+ pad_mode="reflect",
977
+ onesided=True,
978
+ )
979
+ self.assertEqual(spec_list[0].shape, (201, 732))
980
+ self.assertEqual(spec_list[1].shape, (201, 602))
981
+ self.assertEqual(spec_list[2].shape, (201, 1561))
982
+
983
+ spec_list = spectrogram_batch(
984
+ waveform_list,
985
+ window_function(400, "hann"),
986
+ frame_length=400,
987
+ hop_length=128,
988
+ power=1.0,
989
+ center=False,
990
+ pad_mode="reflect",
991
+ onesided=True,
992
+ )
993
+ self.assertEqual(spec_list[0].shape, (201, 729))
994
+ self.assertEqual(spec_list[1].shape, (201, 599))
995
+ self.assertEqual(spec_list[2].shape, (201, 1558))
996
+
997
+ spec_list = spectrogram_batch(
998
+ waveform_list,
999
+ window_function(400, "hann"),
1000
+ frame_length=400,
1001
+ hop_length=128,
1002
+ fft_length=512,
1003
+ power=1.0,
1004
+ center=True,
1005
+ pad_mode="reflect",
1006
+ onesided=True,
1007
+ )
1008
+ self.assertEqual(spec_list[0].shape, (257, 732))
1009
+ self.assertEqual(spec_list[1].shape, (257, 602))
1010
+ self.assertEqual(spec_list[2].shape, (257, 1561))
1011
+
1012
+ spec_list = spectrogram_batch(
1013
+ waveform_list,
1014
+ window_function(400, "hann", frame_length=512),
1015
+ frame_length=512,
1016
+ hop_length=64,
1017
+ power=1.0,
1018
+ center=True,
1019
+ pad_mode="reflect",
1020
+ onesided=False,
1021
+ )
1022
+ self.assertEqual(spec_list[0].shape, (512, 1464))
1023
+ self.assertEqual(spec_list[1].shape, (512, 1204))
1024
+ self.assertEqual(spec_list[2].shape, (512, 3122))
1025
+
1026
+ spec_list = spectrogram_batch(
1027
+ waveform_list,
1028
+ window_function(512, "hann"),
1029
+ frame_length=512,
1030
+ hop_length=64,
1031
+ power=1.0,
1032
+ center=True,
1033
+ pad_mode="reflect",
1034
+ onesided=False,
1035
+ )
1036
+ self.assertEqual(spec_list[0].shape, (512, 1464))
1037
+ self.assertEqual(spec_list[1].shape, (512, 1204))
1038
+ self.assertEqual(spec_list[2].shape, (512, 3122))
1039
+
1040
+ spec_list = spectrogram_batch(
1041
+ waveform_list,
1042
+ window_function(512, "hann"),
1043
+ frame_length=512,
1044
+ hop_length=512,
1045
+ power=1.0,
1046
+ center=True,
1047
+ pad_mode="reflect",
1048
+ onesided=False,
1049
+ )
1050
+ self.assertEqual(spec_list[0].shape, (512, 183))
1051
+ self.assertEqual(spec_list[1].shape, (512, 151))
1052
+ self.assertEqual(spec_list[2].shape, (512, 391))
1053
+
1054
+ def test_mel_spectrogram(self):
1055
+ waveform = self._load_datasamples(1)[0]
1056
+
1057
+ mel_filters = mel_filter_bank(
1058
+ num_frequency_bins=513,
1059
+ num_mel_filters=13,
1060
+ min_frequency=100,
1061
+ max_frequency=4000,
1062
+ sampling_rate=16000,
1063
+ norm=None,
1064
+ mel_scale="htk",
1065
+ )
1066
+ self.assertEqual(mel_filters.shape, (513, 13))
1067
+
1068
+ spec = spectrogram(
1069
+ waveform,
1070
+ window_function(800, "hann", frame_length=1024),
1071
+ frame_length=1024,
1072
+ hop_length=128,
1073
+ power=2.0,
1074
+ )
1075
+ self.assertEqual(spec.shape, (513, 732))
1076
+
1077
+ spec = spectrogram(
1078
+ waveform,
1079
+ window_function(800, "hann", frame_length=1024),
1080
+ frame_length=1024,
1081
+ hop_length=128,
1082
+ power=2.0,
1083
+ mel_filters=mel_filters,
1084
+ )
1085
+ self.assertEqual(spec.shape, (13, 732))
1086
+
1087
+ # fmt: off
1088
+ expected = np.array([
1089
+ 1.08027889e+02, 1.48080673e+01, 7.70758213e+00, 9.57676639e-01,
1090
+ 8.81639061e-02, 5.26073833e-02, 1.52736155e-02, 9.95350117e-03,
1091
+ 7.95364356e-03, 1.01148004e-02, 4.29241020e-03, 9.90708797e-03,
1092
+ 9.44153646e-04
1093
+ ])
1094
+ # fmt: on
1095
+ self.assertTrue(np.allclose(spec[:, 300], expected))
1096
+
1097
+ def test_mel_spectrogram_batch(self):
1098
+ waveform_list = self._load_datasamples(3)
1099
+
1100
+ mel_filters = mel_filter_bank(
1101
+ num_frequency_bins=513,
1102
+ num_mel_filters=13,
1103
+ min_frequency=100,
1104
+ max_frequency=4000,
1105
+ sampling_rate=16000,
1106
+ norm=None,
1107
+ mel_scale="htk",
1108
+ )
1109
+ self.assertEqual(mel_filters.shape, (513, 13))
1110
+
1111
+ spec_list = spectrogram_batch(
1112
+ waveform_list,
1113
+ window_function(800, "hann", frame_length=1024),
1114
+ frame_length=1024,
1115
+ hop_length=128,
1116
+ power=2.0,
1117
+ )
1118
+ self.assertEqual(spec_list[0].shape, (513, 732))
1119
+ self.assertEqual(spec_list[1].shape, (513, 602))
1120
+ self.assertEqual(spec_list[2].shape, (513, 1561))
1121
+
1122
+ spec_list = spectrogram_batch(
1123
+ waveform_list,
1124
+ window_function(800, "hann", frame_length=1024),
1125
+ frame_length=1024,
1126
+ hop_length=128,
1127
+ power=2.0,
1128
+ mel_filters=mel_filters,
1129
+ )
1130
+ self.assertEqual(spec_list[0].shape, (13, 732))
1131
+ self.assertEqual(spec_list[1].shape, (13, 602))
1132
+ self.assertEqual(spec_list[2].shape, (13, 1561))
1133
+
1134
+ # fmt: off
1135
+ expected1 = np.array([
1136
+ 1.08027889e+02, 1.48080673e+01, 7.70758213e+00, 9.57676639e-01,
1137
+ 8.81639061e-02, 5.26073833e-02, 1.52736155e-02, 9.95350117e-03,
1138
+ 7.95364356e-03, 1.01148004e-02, 4.29241020e-03, 9.90708797e-03,
1139
+ 9.44153646e-04
1140
+ ])
1141
+ expected2 = np.array([
1142
+ 71.82577165, 109.44693334, 272.4834194, 164.90450355,
1143
+ 16.54056349, 11.60810547, 24.87525946, 21.07317022,
1144
+ 1.26736284, 1.4583074, 1.36659061, 1.76305768,
1145
+ 2.03703503
1146
+ ])
1147
+ expected3 = np.array([
1148
+ 5.22246749e+02, 6.92660728e+02, 2.65895922e+02, 2.06526565e+01,
1149
+ 2.28692104e+00, 1.19473622e+00, 8.43228216e-01, 3.20760592e+00,
1150
+ 1.33654151e+00, 1.51050684e-01, 2.78282477e-01, 9.25020981e-01,
1151
+ 2.29908841e-01
1152
+ ])
1153
+ # fmt: on
1154
+ self.assertTrue(np.allclose(spec_list[0][:, 300], expected1))
1155
+ self.assertTrue(np.allclose(spec_list[1][:, 300], expected2))
1156
+ self.assertTrue(np.allclose(spec_list[2][:, 300], expected3))
1157
+
1158
+ def test_spectrogram_power(self):
1159
+ waveform = self._load_datasamples(1)[0]
1160
+
1161
+ spec = spectrogram(
1162
+ waveform,
1163
+ window_function(400, "hann", frame_length=512),
1164
+ frame_length=512,
1165
+ hop_length=128,
1166
+ power=None,
1167
+ )
1168
+ self.assertEqual(spec.shape, (257, 732))
1169
+ self.assertEqual(spec.dtype, np.complex64)
1170
+
1171
+ # fmt: off
1172
+ expected = np.array([
1173
+ 0.01452305+0.01820039j, -0.01737362-0.01641946j,
1174
+ 0.0121028 +0.01565081j, -0.02794554-0.03021514j,
1175
+ 0.04719803+0.04086519j, -0.04391563-0.02779365j,
1176
+ 0.05682834+0.01571325j, -0.08604821-0.02023657j,
1177
+ 0.07497991+0.0186641j , -0.06366091-0.00922475j,
1178
+ 0.11003416+0.0114788j , -0.13677941-0.01523552j,
1179
+ 0.10934535-0.00117226j, -0.11635598+0.02551187j,
1180
+ 0.14708674-0.03469823j, -0.1328196 +0.06034218j,
1181
+ 0.12667368-0.13973421j, -0.14764774+0.18912019j,
1182
+ 0.10235471-0.12181523j, -0.00773012+0.04730498j,
1183
+ -0.01487191-0.07312611j, -0.02739162+0.09619419j,
1184
+ 0.02895459-0.05398273j, 0.01198589+0.05276592j,
1185
+ -0.02117299-0.10123465j, 0.00666388+0.09526499j,
1186
+ -0.01672773-0.05649684j, 0.02723125+0.05939891j,
1187
+ -0.01879361-0.062954j , 0.03686557+0.04568823j,
1188
+ -0.07394181-0.07949649j, 0.06238583+0.13905765j,
1189
+ ])
1190
+ # fmt: on
1191
+ self.assertTrue(np.allclose(spec[64:96, 321], expected))
1192
+
1193
+ spec = spectrogram(
1194
+ waveform,
1195
+ window_function(400, "hann", frame_length=512),
1196
+ frame_length=512,
1197
+ hop_length=128,
1198
+ power=1.0,
1199
+ )
1200
+ self.assertEqual(spec.shape, (257, 732))
1201
+ self.assertEqual(spec.dtype, np.float64)
1202
+
1203
+ # fmt: off
1204
+ expected = np.array([
1205
+ 0.02328461, 0.02390484, 0.01978448, 0.04115711, 0.0624309 ,
1206
+ 0.05197181, 0.05896072, 0.08839577, 0.07726794, 0.06432579,
1207
+ 0.11063128, 0.13762532, 0.10935163, 0.11911998, 0.15112405,
1208
+ 0.14588428, 0.18860507, 0.23992978, 0.15910825, 0.04793241,
1209
+ 0.07462307, 0.10001811, 0.06125769, 0.05411011, 0.10342509,
1210
+ 0.09549777, 0.05892122, 0.06534349, 0.06569936, 0.05870678,
1211
+ 0.10856833, 0.1524107 , 0.11463385, 0.05766969, 0.12385171,
1212
+ 0.14472842, 0.11978184, 0.10353675, 0.07244056, 0.03461861,
1213
+ 0.02624896, 0.02227475, 0.01238363, 0.00885281, 0.0110049 ,
1214
+ 0.00807005, 0.01033663, 0.01703181, 0.01445856, 0.00585615,
1215
+ 0.0132431 , 0.02754132, 0.01524478, 0.0204908 , 0.07453328,
1216
+ 0.10716327, 0.07195779, 0.08816078, 0.18340898, 0.16449876,
1217
+ 0.12322842, 0.1621659 , 0.12334293, 0.06033659,
1218
+ ])
1219
+ # fmt: on
1220
+ self.assertTrue(np.allclose(spec[64:128, 321], expected))
1221
+
1222
+ spec = spectrogram(
1223
+ waveform,
1224
+ window_function(400, "hann", frame_length=512),
1225
+ frame_length=512,
1226
+ hop_length=128,
1227
+ power=2.0,
1228
+ )
1229
+ self.assertEqual(spec.shape, (257, 732))
1230
+ self.assertEqual(spec.dtype, np.float64)
1231
+
1232
+ # fmt: off
1233
+ expected = np.array([
1234
+ 5.42173162e-04, 5.71441371e-04, 3.91425507e-04, 1.69390778e-03,
1235
+ 3.89761780e-03, 2.70106923e-03, 3.47636663e-03, 7.81381316e-03,
1236
+ 5.97033510e-03, 4.13780799e-03, 1.22392802e-02, 1.89407300e-02,
1237
+ 1.19577805e-02, 1.41895693e-02, 2.28384770e-02, 2.12822221e-02,
1238
+ 3.55718732e-02, 5.75663000e-02, 2.53154356e-02, 2.29751552e-03,
1239
+ 5.56860259e-03, 1.00036217e-02, 3.75250424e-03, 2.92790355e-03,
1240
+ 1.06967501e-02, 9.11982451e-03, 3.47171025e-03, 4.26977174e-03,
1241
+ 4.31640586e-03, 3.44648538e-03, 1.17870830e-02, 2.32290216e-02,
1242
+ 1.31409196e-02, 3.32579296e-03, 1.53392460e-02, 2.09463164e-02,
1243
+ 1.43476883e-02, 1.07198600e-02, 5.24763530e-03, 1.19844836e-03,
1244
+ 6.89007982e-04, 4.96164430e-04, 1.53354369e-04, 7.83722571e-05,
1245
+ 1.21107812e-04, 6.51257360e-05, 1.06845939e-04, 2.90082477e-04,
1246
+ 2.09049831e-04, 3.42945241e-05, 1.75379610e-04, 7.58524227e-04,
1247
+ 2.32403356e-04, 4.19872697e-04, 5.55520924e-03, 1.14839673e-02,
1248
+ 5.17792348e-03, 7.77232368e-03, 3.36388536e-02, 2.70598419e-02,
1249
+ 1.51852425e-02, 2.62977779e-02, 1.52134784e-02, 3.64050455e-03,
1250
+ ])
1251
+ # fmt: on
1252
+ self.assertTrue(np.allclose(spec[64:128, 321], expected))
1253
+
1254
+ def test_spectrogram_batch_power(self):
1255
+ waveform_list = self._load_datasamples(3)
1256
+
1257
+ spec_list = spectrogram_batch(
1258
+ waveform_list,
1259
+ window_function(400, "hann", frame_length=512),
1260
+ frame_length=512,
1261
+ hop_length=128,
1262
+ power=None,
1263
+ )
1264
+ self.assertEqual(spec_list[0].shape, (257, 732))
1265
+ self.assertEqual(spec_list[0].dtype, np.complex64)
1266
+ self.assertEqual(spec_list[1].shape, (257, 602))
1267
+ self.assertEqual(spec_list[1].dtype, np.complex64)
1268
+ self.assertEqual(spec_list[2].shape, (257, 1561))
1269
+ self.assertEqual(spec_list[2].dtype, np.complex64)
1270
+
1271
+ # fmt: off
1272
+ expected1 = np.array([
1273
+ 0.01452305+0.01820039j, -0.01737362-0.01641946j,
1274
+ 0.0121028 +0.01565081j, -0.02794554-0.03021514j,
1275
+ 0.04719803+0.04086519j, -0.04391563-0.02779365j,
1276
+ 0.05682834+0.01571325j, -0.08604821-0.02023657j,
1277
+ 0.07497991+0.0186641j , -0.06366091-0.00922475j,
1278
+ 0.11003416+0.0114788j , -0.13677941-0.01523552j,
1279
+ 0.10934535-0.00117226j, -0.11635598+0.02551187j,
1280
+ 0.14708674-0.03469823j, -0.1328196 +0.06034218j,
1281
+ 0.12667368-0.13973421j, -0.14764774+0.18912019j,
1282
+ 0.10235471-0.12181523j, -0.00773012+0.04730498j,
1283
+ -0.01487191-0.07312611j, -0.02739162+0.09619419j,
1284
+ 0.02895459-0.05398273j, 0.01198589+0.05276592j,
1285
+ -0.02117299-0.10123465j, 0.00666388+0.09526499j,
1286
+ -0.01672773-0.05649684j, 0.02723125+0.05939891j,
1287
+ -0.01879361-0.062954j , 0.03686557+0.04568823j,
1288
+ -0.07394181-0.07949649j, 0.06238583+0.13905765j,
1289
+ ])
1290
+ expected2 = np.array([
1291
+ -0.01634146-7.0067253e-03j, -0.00068403+9.2661660e-03j,
1292
+ 0.00571721-3.9035487e-03j, -0.00915086+1.5033451e-03j,
1293
+ 0.01138636+5.4256055e-03j, -0.00294282-1.2016168e-02j,
1294
+ -0.00428711+7.3687937e-03j, -0.001002 -1.3972387e-03j,
1295
+ 0.00622582+3.7551194e-03j, -0.00137886-7.0342086e-03j,
1296
+ -0.00824075+3.8430823e-03j, 0.0107349 +7.1450039e-03j,
1297
+ 0.00363763-1.4242286e-02j, -0.01499857+1.7917662e-05j,
1298
+ -0.0046242 +1.2500680e-02j, 0.02180984+7.2047939e-03j,
1299
+ -0.00273568-1.6844695e-02j, -0.00178986-7.5209686e-03j,
1300
+ -0.01661806+1.2662713e-03j, -0.01045276+2.0611197e-02j,
1301
+ 0.03252975+2.5592113e-02j, 0.03945662-6.7136563e-02j,
1302
+ -0.10622615+4.9393820e-03j, 0.06684612+6.4607985e-02j,
1303
+ -0.00753762-5.1637031e-02j, -0.00220644+1.8002450e-02j,
1304
+ -0.00357443-4.1291970e-03j, 0.01463647-1.4063751e-03j,
1305
+ -0.02252573-1.1189026e-02j, 0.00276293+1.9019062e-02j,
1306
+ 0.01216721+1.2095908e-03j, 0.00034753-7.4386634e-03j
1307
+ ])
1308
+ expected3 = np.array([
1309
+ 2.3276670e-02+0.0406534j, -2.4413882e-02-0.07868771j,
1310
+ 1.0993068e-02+0.05550544j, -1.5825305e-02+0.00480187j,
1311
+ 4.7617555e-02-0.04421869j, -7.1669750e-02+0.06317082j,
1312
+ 5.9706111e-02-0.08369736j, -2.2317577e-02+0.08915959j,
1313
+ -2.3291381e-02-0.06601578j, 5.9362967e-02+0.03185856j,
1314
+ -6.5269925e-02+0.0030586j, 5.0898481e-02-0.04319243j,
1315
+ -4.0413942e-02+0.08051146j, 3.0059000e-02-0.09730332j,
1316
+ -1.2479190e-02+0.09703682j, -6.1806822e-03-0.09617531j,
1317
+ 2.6907364e-02+0.08084074j, -4.1639723e-02-0.03391053j,
1318
+ 3.1113219e-02-0.01497662j, 3.4023849e-03+0.03632669j,
1319
+ -4.9804080e-02-0.039231j, 8.9777440e-02+0.02577243j,
1320
+ -9.2947647e-02+0.01514865j, 6.2368069e-02-0.05954866j,
1321
+ -2.9966677e-02+0.06520324j, -8.2365885e-05-0.0440613j ,
1322
+ 2.0203773e-02+0.04350767j, -8.9924788e-04-0.05406843j,
1323
+ -3.5951469e-02+0.03055602j, 3.3790238e-02+0.02182594j,
1324
+ 1.0919777e-03-0.06437822j, -1.8534327e-02+0.07866792j
1325
+ ])
1326
+ # fmt: on
1327
+ self.assertTrue(np.allclose(spec_list[0][64:96, 321], expected1))
1328
+ self.assertTrue(np.allclose(spec_list[1][64:96, 321], expected2))
1329
+ self.assertTrue(np.allclose(spec_list[2][64:96, 321], expected3))
1330
+
1331
+ spec_list = spectrogram_batch(
1332
+ waveform_list,
1333
+ window_function(400, "hann", frame_length=512),
1334
+ frame_length=512,
1335
+ hop_length=128,
1336
+ power=1.0,
1337
+ )
1338
+ self.assertEqual(spec_list[0].shape, (257, 732))
1339
+ self.assertEqual(spec_list[0].dtype, np.float64)
1340
+ self.assertEqual(spec_list[1].shape, (257, 602))
1341
+ self.assertEqual(spec_list[1].dtype, np.float64)
1342
+ self.assertEqual(spec_list[2].shape, (257, 1561))
1343
+ self.assertEqual(spec_list[2].dtype, np.float64)
1344
+
1345
+ # fmt: off
1346
+ expected1 = np.array([
1347
+ 0.02328461, 0.02390484, 0.01978448, 0.04115711, 0.0624309 ,
1348
+ 0.05197181, 0.05896072, 0.08839577, 0.07726794, 0.06432579,
1349
+ 0.11063128, 0.13762532, 0.10935163, 0.11911998, 0.15112405,
1350
+ 0.14588428, 0.18860507, 0.23992978, 0.15910825, 0.04793241,
1351
+ 0.07462307, 0.10001811, 0.06125769, 0.05411011, 0.10342509,
1352
+ 0.09549777, 0.05892122, 0.06534349, 0.06569936, 0.05870678,
1353
+ 0.10856833, 0.1524107 , 0.11463385, 0.05766969, 0.12385171,
1354
+ 0.14472842, 0.11978184, 0.10353675, 0.07244056, 0.03461861,
1355
+ 0.02624896, 0.02227475, 0.01238363, 0.00885281, 0.0110049 ,
1356
+ 0.00807005, 0.01033663, 0.01703181, 0.01445856, 0.00585615,
1357
+ 0.0132431 , 0.02754132, 0.01524478, 0.0204908 , 0.07453328,
1358
+ 0.10716327, 0.07195779, 0.08816078, 0.18340898, 0.16449876,
1359
+ 0.12322842, 0.1621659 , 0.12334293, 0.06033659,
1360
+ ])
1361
+ expected2 = np.array([
1362
+ 0.01778026, 0.00929138, 0.00692273, 0.00927352, 0.01261294,
1363
+ 0.01237128, 0.00852516, 0.00171938, 0.00727061, 0.00716808,
1364
+ 0.00909281, 0.01289532, 0.01469949, 0.01499858, 0.01332855,
1365
+ 0.02296907, 0.01706539, 0.00773101, 0.01666623, 0.02311021,
1366
+ 0.0413901, 0.07787261, 0.10634092, 0.09296556, 0.05218428,
1367
+ 0.01813716, 0.00546139, 0.01470388, 0.02515159, 0.0192187,
1368
+ 0.01222719, 0.00744678, 0.01045674, 0.01923522, 0.01990819,
1369
+ 0.01174323, 0.01535391, 0.02786647, 0.02904595, 0.0313408 ,
1370
+ 0.0340503, 0.03118268, 0.02915136, 0.04200513, 0.05563153,
1371
+ 0.05429446, 0.05021769, 0.05882667, 0.06668596, 0.06555867,
1372
+ 0.04523559, 0.01489498, 0.01031892, 0.02134155, 0.01736669,
1373
+ 0.0195216, 0.03971575, 0.03938636, 0.02052712, 0.03104931,
1374
+ 0.0902727, 0.09022622, 0.03275532, 0.0172633,
1375
+ ])
1376
+ expected3 = np.array([
1377
+ 0.04684551, 0.08238806, 0.05658358, 0.01653778, 0.06498249,
1378
+ 0.09553589, 0.10281084, 0.09191031, 0.07000408, 0.06737158,
1379
+ 0.06534155, 0.06675509, 0.09008541, 0.10184046, 0.09783596,
1380
+ 0.0963737, 0.08520112, 0.05370093, 0.03453015, 0.03648568,
1381
+ 0.06339967, 0.09340346, 0.09417402, 0.08623119, 0.07175977,
1382
+ 0.04406138, 0.04796988, 0.05407591, 0.0471824 , 0.04022626,
1383
+ 0.06438748, 0.0808218, 0.0745263, 0.06191467, 0.03116328,
1384
+ 0.03206497, 0.05867718, 0.04424652, 0.04448404, 0.07032498,
1385
+ 0.08300796, 0.07895744, 0.0816894, 0.09392357, 0.07571699,
1386
+ 0.03967651, 0.07703795, 0.06464871, 0.08704693, 0.14085226,
1387
+ 0.1350321, 0.18794712, 0.27043005, 0.26596246, 0.19948336,
1388
+ 0.06545141, 0.13204652, 0.08554521, 0.2262849, 0.33900721,
1389
+ 0.3970475, 0.3482436, 0.17134947, 0.46249565,
1390
+ ])
1391
+ # fmt: on
1392
+ self.assertTrue(np.allclose(spec_list[0][64:128, 321], expected1))
1393
+ self.assertTrue(np.allclose(spec_list[1][64:128, 321], expected2))
1394
+ self.assertTrue(np.allclose(spec_list[2][64:128, 321], expected3))
1395
+
1396
+ spec_list = spectrogram_batch(
1397
+ waveform_list,
1398
+ window_function(400, "hann", frame_length=512),
1399
+ frame_length=512,
1400
+ hop_length=128,
1401
+ power=2.0,
1402
+ )
1403
+ self.assertEqual(spec_list[0].shape, (257, 732))
1404
+ self.assertEqual(spec_list[0].dtype, np.float64)
1405
+ self.assertEqual(spec_list[1].shape, (257, 602))
1406
+ self.assertEqual(spec_list[1].dtype, np.float64)
1407
+ self.assertEqual(spec_list[2].shape, (257, 1561))
1408
+ self.assertEqual(spec_list[2].dtype, np.float64)
1409
+
1410
+ # fmt: off
1411
+ expected1 = np.array([
1412
+ 5.42173162e-04, 5.71441371e-04, 3.91425507e-04, 1.69390778e-03,
1413
+ 3.89761780e-03, 2.70106923e-03, 3.47636663e-03, 7.81381316e-03,
1414
+ 5.97033510e-03, 4.13780799e-03, 1.22392802e-02, 1.89407300e-02,
1415
+ 1.19577805e-02, 1.41895693e-02, 2.28384770e-02, 2.12822221e-02,
1416
+ 3.55718732e-02, 5.75663000e-02, 2.53154356e-02, 2.29751552e-03,
1417
+ 5.56860259e-03, 1.00036217e-02, 3.75250424e-03, 2.92790355e-03,
1418
+ 1.06967501e-02, 9.11982451e-03, 3.47171025e-03, 4.26977174e-03,
1419
+ 4.31640586e-03, 3.44648538e-03, 1.17870830e-02, 2.32290216e-02,
1420
+ 1.31409196e-02, 3.32579296e-03, 1.53392460e-02, 2.09463164e-02,
1421
+ 1.43476883e-02, 1.07198600e-02, 5.24763530e-03, 1.19844836e-03,
1422
+ 6.89007982e-04, 4.96164430e-04, 1.53354369e-04, 7.83722571e-05,
1423
+ 1.21107812e-04, 6.51257360e-05, 1.06845939e-04, 2.90082477e-04,
1424
+ 2.09049831e-04, 3.42945241e-05, 1.75379610e-04, 7.58524227e-04,
1425
+ 2.32403356e-04, 4.19872697e-04, 5.55520924e-03, 1.14839673e-02,
1426
+ 5.17792348e-03, 7.77232368e-03, 3.36388536e-02, 2.70598419e-02,
1427
+ 1.51852425e-02, 2.62977779e-02, 1.52134784e-02, 3.64050455e-03,
1428
+ ])
1429
+ expected2 = np.array([
1430
+ 3.16137604e-04, 8.63297362e-05, 4.79241720e-05, 8.59982493e-05,
1431
+ 1.59086326e-04, 1.53048476e-04, 7.26783945e-05, 2.95627100e-06,
1432
+ 5.28617352e-05, 5.13813355e-05, 8.26792588e-05, 1.66289156e-04,
1433
+ 2.16075069e-04, 2.24957314e-04, 1.77650211e-04, 5.27578282e-04,
1434
+ 2.91227688e-04, 5.97685493e-05, 2.77763360e-04, 5.34081651e-04,
1435
+ 1.71314057e-03, 6.06414277e-03, 1.13083916e-02, 8.64259617e-03,
1436
+ 2.72319867e-03, 3.28956593e-04, 2.98268126e-05, 2.16204145e-04,
1437
+ 6.32602626e-04, 3.69358508e-04, 1.49504171e-04, 5.54544917e-05,
1438
+ 1.09343371e-04, 3.69993847e-04, 3.96335839e-04, 1.37903521e-04,
1439
+ 2.35742483e-04, 7.76540114e-04, 8.43667068e-04, 9.82245923e-04,
1440
+ 1.15942286e-03, 9.72359636e-04, 8.49801853e-04, 1.76443092e-03,
1441
+ 3.09486753e-03, 2.94788822e-03, 2.52181630e-03, 3.46057723e-03,
1442
+ 4.44701769e-03, 4.29793858e-03, 2.04625858e-03, 2.21860290e-04,
1443
+ 1.06480179e-04, 4.55461892e-04, 3.01601836e-04, 3.81092892e-04,
1444
+ 1.57734053e-03, 1.55128531e-03, 4.21362677e-04, 9.64059883e-04,
1445
+ 8.14916019e-03, 8.14077014e-03, 1.07291131e-03, 2.98021545e-04,
1446
+ ])
1447
+ expected3 = np.array([
1448
+ 0.0021945 , 0.00678779, 0.0032017 , 0.0002735 , 0.00422272,
1449
+ 0.00912711, 0.01057007, 0.00844751, 0.00490057, 0.00453893,
1450
+ 0.00426952, 0.00445624, 0.00811538, 0.01037148, 0.00957188,
1451
+ 0.00928789, 0.00725923, 0.00288379, 0.00119233, 0.0013312 ,
1452
+ 0.00401952, 0.00872421, 0.00886875, 0.00743582, 0.00514946,
1453
+ 0.00194141, 0.00230111, 0.0029242 , 0.00222618, 0.00161815,
1454
+ 0.00414575, 0.00653216, 0.00555417, 0.00383343, 0.00097115,
1455
+ 0.00102816, 0.00344301, 0.00195775, 0.00197883, 0.0049456 ,
1456
+ 0.00689032, 0.00623428, 0.00667316, 0.00882164, 0.00573306,
1457
+ 0.00157423, 0.00593485, 0.00417946, 0.00757717, 0.01983936,
1458
+ 0.01823367, 0.03532412, 0.07313241, 0.07073603, 0.03979361,
1459
+ 0.00428389, 0.01743628, 0.00731798, 0.05120486, 0.11492589,
1460
+ 0.15764671, 0.1212736 , 0.02936064, 0.21390222
1461
+ ])
1462
+ # fmt: on
1463
+ self.assertTrue(np.allclose(spec_list[0][64:128, 321], expected1))
1464
+ self.assertTrue(np.allclose(spec_list[1][64:128, 321], expected2))
1465
+ self.assertTrue(np.allclose(spec_list[2][64:128, 321], expected3))
1466
+
1467
+ def test_power_to_db(self):
1468
+ spectrogram = np.zeros((2, 3))
1469
+ spectrogram[0, 0] = 2.0
1470
+ spectrogram[0, 1] = 0.5
1471
+ spectrogram[0, 2] = 0.707
1472
+ spectrogram[1, 1] = 1.0
1473
+
1474
+ output = power_to_db(spectrogram, reference=1.0)
1475
+ expected = np.array([[3.01029996, -3.01029996, -1.50580586], [-100.0, 0.0, -100.0]])
1476
+ self.assertTrue(np.allclose(output, expected))
1477
+
1478
+ output = power_to_db(spectrogram, reference=2.0)
1479
+ expected = np.array([[0.0, -6.02059991, -4.51610582], [-103.01029996, -3.01029996, -103.01029996]])
1480
+ self.assertTrue(np.allclose(output, expected))
1481
+
1482
+ output = power_to_db(spectrogram, min_value=1e-6)
1483
+ expected = np.array([[3.01029996, -3.01029996, -1.50580586], [-60.0, 0.0, -60.0]])
1484
+ self.assertTrue(np.allclose(output, expected))
1485
+
1486
+ output = power_to_db(spectrogram, db_range=80)
1487
+ expected = np.array([[3.01029996, -3.01029996, -1.50580586], [-76.98970004, 0.0, -76.98970004]])
1488
+ self.assertTrue(np.allclose(output, expected))
1489
+
1490
+ output = power_to_db(spectrogram, reference=2.0, db_range=80)
1491
+ expected = np.array([[0.0, -6.02059991, -4.51610582], [-80.0, -3.01029996, -80.0]])
1492
+ self.assertTrue(np.allclose(output, expected))
1493
+
1494
+ output = power_to_db(spectrogram, reference=2.0, min_value=1e-6, db_range=80)
1495
+ expected = np.array([[0.0, -6.02059991, -4.51610582], [-63.01029996, -3.01029996, -63.01029996]])
1496
+ self.assertTrue(np.allclose(output, expected))
1497
+
1498
+ with pytest.raises(ValueError):
1499
+ power_to_db(spectrogram, reference=0.0)
1500
+ with pytest.raises(ValueError):
1501
+ power_to_db(spectrogram, min_value=0.0)
1502
+ with pytest.raises(ValueError):
1503
+ power_to_db(spectrogram, db_range=-80)
1504
+
1505
+ def test_power_to_db_batch(self):
1506
+ # Setup a batch of spectrograms with varying values and lengths
1507
+ batch_spectrogram = np.zeros((3, 2, 3))
1508
+ batch_spectrogram[0, 0, 0] = 2.0
1509
+ batch_spectrogram[0, 0, 1] = 0.5
1510
+ batch_spectrogram[0, 0, 2] = 0.707
1511
+ batch_spectrogram[0, 1, 1] = 1.0
1512
+ batch_spectrogram[1, :, :2] = batch_spectrogram[0, :, :2] * 1.5
1513
+ batch_spectrogram[2, :, :1] = batch_spectrogram[0, :, :1] * 0.5
1514
+
1515
+ # Expected values computed by applying `power_to_db` iteratively
1516
+ output = power_to_db_batch(batch_spectrogram, reference=1.0)
1517
+ expected = np.array(
1518
+ [
1519
+ [[3.01029996, -3.01029996, -1.50580586], [-100, 0, -100]],
1520
+ [[4.77121255, -1.24938737, -100], [-100, 1.76091259, -100]],
1521
+ [[0, -100, -100], [-100, -100, -100]],
1522
+ ]
1523
+ )
1524
+ self.assertTrue(np.allclose(output, expected))
1525
+
1526
+ output = power_to_db_batch(batch_spectrogram, reference=2.0)
1527
+ expected = np.array(
1528
+ [
1529
+ [[0, -6.02059991, -4.51610582], [-103.01029996, -3.01029996, -103.01029996]],
1530
+ [[1.76091259, -4.25968732, -103.01029996], [-103.01029996, -1.24938737, -103.01029996]],
1531
+ [[-3.01029996, -103.01029996, -103.01029996], [-103.01029996, -103.01029996, -103.01029996]],
1532
+ ]
1533
+ )
1534
+ self.assertTrue(np.allclose(output, expected))
1535
+
1536
+ output = power_to_db_batch(batch_spectrogram, min_value=1e-6)
1537
+ expected = np.array(
1538
+ [
1539
+ [[3.01029996, -3.01029996, -1.50580586], [-60, 0, -60]],
1540
+ [[4.77121255, -1.24938737, -60], [-60, 1.76091259, -60]],
1541
+ [[0, -60, -60], [-60, -60, -60]],
1542
+ ]
1543
+ )
1544
+ self.assertTrue(np.allclose(output, expected))
1545
+
1546
+ output = power_to_db_batch(batch_spectrogram, db_range=80)
1547
+ expected = np.array(
1548
+ [
1549
+ [[3.01029996, -3.01029996, -1.50580586], [-76.98970004, 0, -76.98970004]],
1550
+ [[4.77121255, -1.24938737, -75.22878745], [-75.22878745, 1.76091259, -75.22878745]],
1551
+ [[0, -80, -80], [-80, -80, -80]],
1552
+ ]
1553
+ )
1554
+ self.assertTrue(np.allclose(output, expected))
1555
+
1556
+ output = power_to_db_batch(batch_spectrogram, reference=2.0, db_range=80)
1557
+ expected = np.array(
1558
+ [
1559
+ [[0, -6.02059991, -4.51610582], [-80, -3.01029996, -80]],
1560
+ [[1.76091259, -4.25968732, -78.23908741], [-78.23908741, -1.24938737, -78.23908741]],
1561
+ [[-3.01029996, -83.01029996, -83.01029996], [-83.01029996, -83.01029996, -83.01029996]],
1562
+ ]
1563
+ )
1564
+ self.assertTrue(np.allclose(output, expected))
1565
+
1566
+ output = power_to_db_batch(batch_spectrogram, reference=2.0, min_value=1e-6, db_range=80)
1567
+ expected = np.array(
1568
+ [
1569
+ [[0, -6.02059991, -4.51610582], [-63.01029996, -3.01029996, -63.01029996]],
1570
+ [[1.76091259, -4.25968732, -63.01029996], [-63.01029996, -1.24938737, -63.01029996]],
1571
+ [[-3.01029996, -63.01029996, -63.01029996], [-63.01029996, -63.01029996, -63.01029996]],
1572
+ ]
1573
+ )
1574
+ self.assertTrue(np.allclose(output, expected))
1575
+
1576
+ with pytest.raises(ValueError):
1577
+ power_to_db_batch(batch_spectrogram, reference=0.0)
1578
+ with pytest.raises(ValueError):
1579
+ power_to_db_batch(batch_spectrogram, min_value=0.0)
1580
+ with pytest.raises(ValueError):
1581
+ power_to_db_batch(batch_spectrogram, db_range=-80)
1582
+
1583
+ def test_amplitude_to_db(self):
1584
+ spectrogram = np.zeros((2, 3))
1585
+ spectrogram[0, 0] = 2.0
1586
+ spectrogram[0, 1] = 0.5
1587
+ spectrogram[0, 2] = 0.707
1588
+ spectrogram[1, 1] = 1.0
1589
+
1590
+ output = amplitude_to_db(spectrogram, reference=1.0)
1591
+ expected = np.array([[6.02059991, -6.02059991, -3.01161172], [-100.0, 0.0, -100.0]])
1592
+ self.assertTrue(np.allclose(output, expected))
1593
+
1594
+ output = amplitude_to_db(spectrogram, reference=2.0)
1595
+ expected = np.array([[0.0, -12.04119983, -9.03221164], [-106.02059991, -6.02059991, -106.02059991]])
1596
+ self.assertTrue(np.allclose(output, expected))
1597
+
1598
+ output = amplitude_to_db(spectrogram, min_value=1e-3)
1599
+ expected = np.array([[6.02059991, -6.02059991, -3.01161172], [-60.0, 0.0, -60.0]])
1600
+ self.assertTrue(np.allclose(output, expected))
1601
+
1602
+ output = amplitude_to_db(spectrogram, db_range=80)
1603
+ expected = np.array([[6.02059991, -6.02059991, -3.01161172], [-73.97940009, 0.0, -73.97940009]])
1604
+ self.assertTrue(np.allclose(output, expected))
1605
+
1606
+ output = amplitude_to_db(spectrogram, reference=2.0, db_range=80)
1607
+ expected = np.array([[0.0, -12.04119983, -9.03221164], [-80.0, -6.02059991, -80.0]])
1608
+ self.assertTrue(np.allclose(output, expected))
1609
+
1610
+ output = amplitude_to_db(spectrogram, reference=2.0, min_value=1e-3, db_range=80)
1611
+ expected = np.array([[0.0, -12.04119983, -9.03221164], [-66.02059991, -6.02059991, -66.02059991]])
1612
+ self.assertTrue(np.allclose(output, expected))
1613
+
1614
+ with pytest.raises(ValueError):
1615
+ amplitude_to_db(spectrogram, reference=0.0)
1616
+ with pytest.raises(ValueError):
1617
+ amplitude_to_db(spectrogram, min_value=0.0)
1618
+ with pytest.raises(ValueError):
1619
+ amplitude_to_db(spectrogram, db_range=-80)
1620
+
1621
+ def test_amplitude_to_db_batch(self):
1622
+ # Setup a batch of spectrograms with varying values and lengths
1623
+ batch_spectrogram = np.zeros((3, 2, 3))
1624
+ batch_spectrogram[0, 0, 0] = 2.0
1625
+ batch_spectrogram[0, 0, 1] = 0.5
1626
+ batch_spectrogram[0, 0, 2] = 0.707
1627
+ batch_spectrogram[0, 1, 1] = 1.0
1628
+ batch_spectrogram[1, :, :2] = batch_spectrogram[0, :, :2] * 1.5
1629
+ batch_spectrogram[2, :, :1] = batch_spectrogram[0, :, :1] * 0.5
1630
+
1631
+ # Expected values computed by applying `amplitude_to_db` iteratively
1632
+ output = amplitude_to_db_batch(batch_spectrogram, reference=1.0)
1633
+ expected = np.array(
1634
+ [
1635
+ [[6.02059991, -6.02059991, -3.01161172], [-100, 0, -100]],
1636
+ [[9.54242509, -2.49877473, -100], [-100, 3.52182518, -100]],
1637
+ [[0, -100, -100], [-100, -100, -100]],
1638
+ ]
1639
+ )
1640
+ self.assertTrue(np.allclose(output, expected))
1641
+
1642
+ output = amplitude_to_db_batch(batch_spectrogram, reference=2.0)
1643
+ expected = np.array(
1644
+ [
1645
+ [[0, -12.04119983, -9.03221164], [-106.02059991, -6.02059991, -106.02059991]],
1646
+ [[3.52182518, -8.51937465, -106.02059991], [-106.02059991, -2.49877473, -106.02059991]],
1647
+ [[-6.02059991, -106.02059991, -106.02059991], [-106.02059991, -106.02059991, -106.02059991]],
1648
+ ]
1649
+ )
1650
+ self.assertTrue(np.allclose(output, expected))
1651
+
1652
+ output = amplitude_to_db_batch(batch_spectrogram, min_value=1e-3)
1653
+ expected = np.array(
1654
+ [
1655
+ [[6.02059991, -6.02059991, -3.01161172], [-60, 0, -60]],
1656
+ [[9.54242509, -2.49877473, -60], [-60, 3.52182518, -60]],
1657
+ [[0, -60, -60], [-60, -60, -60]],
1658
+ ]
1659
+ )
1660
+ self.assertTrue(np.allclose(output, expected))
1661
+
1662
+ output = amplitude_to_db_batch(batch_spectrogram, db_range=80)
1663
+ expected = np.array(
1664
+ [
1665
+ [[6.02059991, -6.02059991, -3.01161172], [-73.97940009, 0, -73.97940009]],
1666
+ [[9.54242509, -2.49877473, -70.45757491], [-70.45757491, 3.52182518, -70.45757491]],
1667
+ [[0, -80, -80], [-80, -80, -80]],
1668
+ ]
1669
+ )
1670
+ self.assertTrue(np.allclose(output, expected))
1671
+
1672
+ output = amplitude_to_db_batch(batch_spectrogram, reference=2.0, db_range=80)
1673
+ expected = np.array(
1674
+ [
1675
+ [[0, -12.04119983, -9.03221164], [-80, -6.02059991, -80]],
1676
+ [[3.52182518, -8.51937465, -76.47817482], [-76.47817482, -2.49877473, -76.47817482]],
1677
+ [[-6.02059991, -86.02059991, -86.02059991], [-86.02059991, -86.02059991, -86.02059991]],
1678
+ ]
1679
+ )
1680
+ self.assertTrue(np.allclose(output, expected))
1681
+
1682
+ output = amplitude_to_db_batch(batch_spectrogram, reference=2.0, min_value=1e-3, db_range=80)
1683
+ expected = np.array(
1684
+ [
1685
+ [[0, -12.04119983, -9.03221164], [-66.02059991, -6.02059991, -66.02059991]],
1686
+ [[3.52182518, -8.51937465, -66.02059991], [-66.02059991, -2.49877473, -66.02059991]],
1687
+ [[-6.02059991, -66.02059991, -66.02059991], [-66.02059991, -66.02059991, -66.02059991]],
1688
+ ]
1689
+ )
1690
+ self.assertTrue(np.allclose(output, expected))
1691
+
1692
+ with pytest.raises(ValueError):
1693
+ amplitude_to_db_batch(batch_spectrogram, reference=0.0)
1694
+ with pytest.raises(ValueError):
1695
+ amplitude_to_db_batch(batch_spectrogram, min_value=0.0)
1696
+ with pytest.raises(ValueError):
1697
+ amplitude_to_db_batch(batch_spectrogram, db_range=-80)
1698
+
1699
+ @require_librosa
1700
+ def test_chroma_equivalence(self):
1701
+ num_frequency_bins = 25
1702
+ num_chroma = 6
1703
+ sampling_rate = 24000
1704
+
1705
+ # test default parameters
1706
+ original_chroma = chroma(sr=sampling_rate, n_chroma=num_chroma, n_fft=num_frequency_bins)
1707
+ utils_chroma = chroma_filter_bank(
1708
+ num_frequency_bins=num_frequency_bins, num_chroma=num_chroma, sampling_rate=sampling_rate
1709
+ )
1710
+
1711
+ self.assertTrue(np.allclose(original_chroma, utils_chroma))
1712
+
1713
+ # test no weighting_parameters
1714
+ original_chroma = chroma(sr=sampling_rate, n_chroma=num_chroma, n_fft=num_frequency_bins, octwidth=None)
1715
+ utils_chroma = chroma_filter_bank(
1716
+ num_frequency_bins=num_frequency_bins,
1717
+ num_chroma=num_chroma,
1718
+ sampling_rate=sampling_rate,
1719
+ weighting_parameters=None,
1720
+ )
1721
+
1722
+ self.assertTrue(np.allclose(original_chroma, utils_chroma))
1723
+
1724
+ # test with L1 norm
1725
+ original_chroma = chroma(sr=sampling_rate, n_chroma=num_chroma, n_fft=num_frequency_bins, norm=1.0)
1726
+ utils_chroma = chroma_filter_bank(
1727
+ num_frequency_bins=num_frequency_bins, num_chroma=num_chroma, sampling_rate=sampling_rate, power=1.0
1728
+ )
1729
+
1730
+ self.assertTrue(np.allclose(original_chroma, utils_chroma))
1731
+
1732
+ # test starting at 'A' chroma, power = None, tuning = 0, different weighting_parameters
1733
+ original_chroma = chroma(
1734
+ sr=sampling_rate,
1735
+ n_chroma=num_chroma,
1736
+ n_fft=num_frequency_bins,
1737
+ norm=None,
1738
+ base_c=None,
1739
+ octwidth=1.0,
1740
+ ctroct=4.0,
1741
+ )
1742
+ utils_chroma = chroma_filter_bank(
1743
+ num_frequency_bins=num_frequency_bins,
1744
+ num_chroma=num_chroma,
1745
+ sampling_rate=sampling_rate,
1746
+ power=None,
1747
+ start_at_c_chroma=False,
1748
+ weighting_parameters=(4.0, 1.0),
1749
+ )
1750
+
1751
+ self.assertTrue(np.allclose(original_chroma, utils_chroma))
docs/transformers/tests/utils/test_backbone_utils.py ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import unittest
16
+
17
+ import pytest
18
+
19
+ from transformers import DetrConfig, MaskFormerConfig, ResNetBackbone, ResNetConfig, TimmBackbone
20
+ from transformers.testing_utils import require_torch, slow
21
+ from transformers.utils.backbone_utils import (
22
+ BackboneMixin,
23
+ get_aligned_output_features_output_indices,
24
+ load_backbone,
25
+ verify_out_features_out_indices,
26
+ )
27
+ from transformers.utils.import_utils import is_torch_available
28
+
29
+
30
+ if is_torch_available():
31
+ import torch
32
+
33
+ from transformers import BertPreTrainedModel
34
+
35
+
36
+ class BackboneUtilsTester(unittest.TestCase):
37
+ def test_get_aligned_output_features_output_indices(self):
38
+ stage_names = ["a", "b", "c"]
39
+
40
+ # Defaults to last layer if both are None
41
+ out_features, out_indices = get_aligned_output_features_output_indices(None, None, stage_names)
42
+ self.assertEqual(out_features, ["c"])
43
+ self.assertEqual(out_indices, [2])
44
+
45
+ # Out indices set to match out features
46
+ out_features, out_indices = get_aligned_output_features_output_indices(["a", "c"], None, stage_names)
47
+ self.assertEqual(out_features, ["a", "c"])
48
+ self.assertEqual(out_indices, [0, 2])
49
+
50
+ # Out features set to match out indices
51
+ out_features, out_indices = get_aligned_output_features_output_indices(None, [0, 2], stage_names)
52
+ self.assertEqual(out_features, ["a", "c"])
53
+ self.assertEqual(out_indices, [0, 2])
54
+
55
+ # Out features selected from negative indices
56
+ out_features, out_indices = get_aligned_output_features_output_indices(None, [-3, -1], stage_names)
57
+ self.assertEqual(out_features, ["a", "c"])
58
+ self.assertEqual(out_indices, [-3, -1])
59
+
60
+ def test_verify_out_features_out_indices(self):
61
+ # Stage names must be set
62
+ with pytest.raises(ValueError, match="Stage_names must be set for transformers backbones"):
63
+ verify_out_features_out_indices(["a", "b"], (0, 1), None)
64
+
65
+ # Out features must be a list
66
+ with pytest.raises(ValueError, match="out_features must be a list got <class 'tuple'>"):
67
+ verify_out_features_out_indices(("a", "b"), (0, 1), ["a", "b"])
68
+
69
+ # Out features must be a subset of stage names
70
+ with pytest.raises(
71
+ ValueError, match=r"out_features must be a subset of stage_names: \['a'\] got \['a', 'b'\]"
72
+ ):
73
+ verify_out_features_out_indices(["a", "b"], [0, 1], ["a"])
74
+
75
+ # Out features must contain no duplicates
76
+ with pytest.raises(ValueError, match=r"out_features must not contain any duplicates, got \['a', 'a'\]"):
77
+ verify_out_features_out_indices(["a", "a"], None, ["a"])
78
+
79
+ # Out indices must be a list
80
+ with pytest.raises(ValueError, match="out_indices must be a list, got <class 'int'>"):
81
+ verify_out_features_out_indices(None, 0, ["a", "b"])
82
+
83
+ with pytest.raises(ValueError, match="out_indices must be a list, got <class 'tuple'>"):
84
+ verify_out_features_out_indices(None, (0, 1), ["a", "b"])
85
+
86
+ # Out indices must be a subset of stage names
87
+ with pytest.raises(
88
+ ValueError, match=r"out_indices must be valid indices for stage_names \['a'\], got \[0, 1\]"
89
+ ):
90
+ verify_out_features_out_indices(None, [0, 1], ["a"])
91
+
92
+ # Out indices must contain no duplicates
93
+ with pytest.raises(ValueError, match=r"out_indices must not contain any duplicates, got \[0, 0\]"):
94
+ verify_out_features_out_indices(None, [0, 0], ["a"])
95
+
96
+ # Out features and out indices must be the same length
97
+ with pytest.raises(
98
+ ValueError, match="out_features and out_indices should have the same length if both are set"
99
+ ):
100
+ verify_out_features_out_indices(["a", "b"], [0], ["a", "b", "c"])
101
+
102
+ # Out features should match out indices
103
+ with pytest.raises(
104
+ ValueError, match="out_features and out_indices should correspond to the same stages if both are set"
105
+ ):
106
+ verify_out_features_out_indices(["a", "b"], [0, 2], ["a", "b", "c"])
107
+
108
+ # Out features and out indices should be in order
109
+ with pytest.raises(
110
+ ValueError,
111
+ match=r"out_features must be in the same order as stage_names, expected \['a', 'b'\] got \['b', 'a'\]",
112
+ ):
113
+ verify_out_features_out_indices(["b", "a"], [0, 1], ["a", "b"])
114
+
115
+ with pytest.raises(
116
+ ValueError, match=r"out_indices must be in the same order as stage_names, expected \[-2, 1\] got \[1, -2\]"
117
+ ):
118
+ verify_out_features_out_indices(["a", "b"], [1, -2], ["a", "b"])
119
+
120
+ # Check passes with valid inputs
121
+ verify_out_features_out_indices(["a", "b", "d"], [0, 1, -1], ["a", "b", "c", "d"])
122
+
123
+ def test_backbone_mixin(self):
124
+ backbone = BackboneMixin()
125
+
126
+ backbone.stage_names = ["a", "b", "c"]
127
+ backbone._out_features = ["a", "c"]
128
+ backbone._out_indices = [0, 2]
129
+
130
+ # Check that the output features and indices are set correctly
131
+ self.assertEqual(backbone.out_features, ["a", "c"])
132
+ self.assertEqual(backbone.out_indices, [0, 2])
133
+
134
+ # Check out features and indices are updated correctly
135
+ backbone.out_features = ["a", "b"]
136
+ self.assertEqual(backbone.out_features, ["a", "b"])
137
+ self.assertEqual(backbone.out_indices, [0, 1])
138
+
139
+ backbone.out_indices = [-3, -1]
140
+ self.assertEqual(backbone.out_features, ["a", "c"])
141
+ self.assertEqual(backbone.out_indices, [-3, -1])
142
+
143
+ @slow
144
+ @require_torch
145
+ def test_load_backbone_from_config(self):
146
+ """
147
+ Test that load_backbone correctly loads a backbone from a backbone config.
148
+ """
149
+ config = MaskFormerConfig(backbone_config=ResNetConfig(out_indices=(0, 2)))
150
+ backbone = load_backbone(config)
151
+ self.assertEqual(backbone.out_features, ["stem", "stage2"])
152
+ self.assertEqual(backbone.out_indices, (0, 2))
153
+ self.assertIsInstance(backbone, ResNetBackbone)
154
+
155
+ @slow
156
+ @require_torch
157
+ def test_load_backbone_from_checkpoint(self):
158
+ """
159
+ Test that load_backbone correctly loads a backbone from a checkpoint.
160
+ """
161
+ config = MaskFormerConfig(backbone="microsoft/resnet-18", backbone_config=None)
162
+ backbone = load_backbone(config)
163
+ self.assertEqual(backbone.out_indices, [4])
164
+ self.assertEqual(backbone.out_features, ["stage4"])
165
+ self.assertIsInstance(backbone, ResNetBackbone)
166
+
167
+ config = MaskFormerConfig(
168
+ backbone="resnet18",
169
+ use_timm_backbone=True,
170
+ )
171
+ backbone = load_backbone(config)
172
+ # We can't know ahead of time the exact output features and indices, or the layer names before
173
+ # creating the timm model, so it defaults to the last layer (-1,) and has a different layer name
174
+ self.assertEqual(backbone.out_indices, (-1,))
175
+ self.assertEqual(backbone.out_features, ["layer4"])
176
+ self.assertIsInstance(backbone, TimmBackbone)
177
+
178
+ @slow
179
+ @require_torch
180
+ def test_load_backbone_backbone_kwargs(self):
181
+ """
182
+ Test that load_backbone correctly configures the loaded backbone with the provided kwargs.
183
+ """
184
+ config = MaskFormerConfig(backbone="resnet18", use_timm_backbone=True, backbone_kwargs={"out_indices": (0, 1)})
185
+ backbone = load_backbone(config)
186
+ self.assertEqual(backbone.out_indices, (0, 1))
187
+ self.assertIsInstance(backbone, TimmBackbone)
188
+
189
+ config = MaskFormerConfig(backbone="microsoft/resnet-18", backbone_kwargs={"out_indices": (0, 2)})
190
+ backbone = load_backbone(config)
191
+ self.assertEqual(backbone.out_indices, (0, 2))
192
+ self.assertIsInstance(backbone, ResNetBackbone)
193
+
194
+ # Check can't be passed with a backone config
195
+ with pytest.raises(ValueError):
196
+ config = MaskFormerConfig(
197
+ backbone="microsoft/resnet-18",
198
+ backbone_config=ResNetConfig(out_indices=(0, 2)),
199
+ backbone_kwargs={"out_indices": (0, 1)},
200
+ )
201
+
202
+ @slow
203
+ @require_torch
204
+ def test_load_backbone_in_new_model(self):
205
+ """
206
+ Tests that new model can be created, with its weights instantiated and pretrained backbone weights loaded.
207
+ """
208
+
209
+ # Inherit from PreTrainedModel to ensure that the weights are initialized
210
+ class NewModel(BertPreTrainedModel):
211
+ def __init__(self, config):
212
+ super().__init__(config)
213
+ self.backbone = load_backbone(config)
214
+ self.layer_0 = torch.nn.Linear(config.hidden_size, config.hidden_size)
215
+ self.layer_1 = torch.nn.Linear(config.hidden_size, config.hidden_size)
216
+
217
+ def get_equal_not_equal_weights(model_0, model_1):
218
+ equal_weights = []
219
+ not_equal_weights = []
220
+ for (k0, v0), (k1, v1) in zip(model_0.named_parameters(), model_1.named_parameters()):
221
+ self.assertEqual(k0, k1)
222
+ weights_are_equal = torch.allclose(v0, v1)
223
+ if weights_are_equal:
224
+ equal_weights.append(k0)
225
+ else:
226
+ not_equal_weights.append(k0)
227
+ return equal_weights, not_equal_weights
228
+
229
+ config = MaskFormerConfig(use_pretrained_backbone=False, backbone="microsoft/resnet-18")
230
+ model_0 = NewModel(config)
231
+ model_1 = NewModel(config)
232
+ equal_weights, not_equal_weights = get_equal_not_equal_weights(model_0, model_1)
233
+
234
+ # Norm layers are always initialized with the same weights
235
+ equal_weights = [w for w in equal_weights if "normalization" not in w]
236
+ self.assertEqual(len(equal_weights), 0)
237
+ self.assertEqual(len(not_equal_weights), 24)
238
+
239
+ # Now we create a new model with backbone weights that are pretrained
240
+ config.use_pretrained_backbone = True
241
+ model_0 = NewModel(config)
242
+ model_1 = NewModel(config)
243
+ equal_weights, not_equal_weights = get_equal_not_equal_weights(model_0, model_1)
244
+
245
+ # Norm layers are always initialized with the same weights
246
+ equal_weights = [w for w in equal_weights if "normalization" not in w]
247
+ self.assertEqual(len(equal_weights), 20)
248
+ # Linear layers are still initialized randomly
249
+ self.assertEqual(len(not_equal_weights), 4)
250
+
251
+ # Check loading in timm backbone
252
+ config = DetrConfig(use_pretrained_backbone=False, backbone="resnet18", use_timm_backbone=True)
253
+ model_0 = NewModel(config)
254
+ model_1 = NewModel(config)
255
+ equal_weights, not_equal_weights = get_equal_not_equal_weights(model_0, model_1)
256
+
257
+ # Norm layers are always initialized with the same weights
258
+ equal_weights = [w for w in equal_weights if "bn" not in w and "downsample.1" not in w]
259
+ self.assertEqual(len(equal_weights), 0)
260
+ self.assertEqual(len(not_equal_weights), 24)
261
+
262
+ # Now we create a new model with backbone weights that are pretrained
263
+ config.use_pretrained_backbone = True
264
+ model_0 = NewModel(config)
265
+ model_1 = NewModel(config)
266
+ equal_weights, not_equal_weights = get_equal_not_equal_weights(model_0, model_1)
267
+
268
+ # Norm layers are always initialized with the same weights
269
+ equal_weights = [w for w in equal_weights if "bn" not in w and "downsample.1" not in w]
270
+ self.assertEqual(len(equal_weights), 20)
271
+ # Linear layers are still initialized randomly
272
+ self.assertEqual(len(not_equal_weights), 4)
docs/transformers/tests/utils/test_cache_utils.py ADDED
@@ -0,0 +1,766 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 HuggingFace Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import copy
16
+ import unittest
17
+
18
+ from parameterized import parameterized
19
+
20
+ from transformers import set_seed
21
+ from transformers.testing_utils import (
22
+ CaptureStderr,
23
+ get_gpu_count,
24
+ is_torch_available,
25
+ require_gptq,
26
+ require_non_xpu,
27
+ require_read_token,
28
+ require_torch,
29
+ require_torch_accelerator,
30
+ require_torch_gpu,
31
+ require_torch_multi_gpu,
32
+ slow,
33
+ torch_device,
34
+ )
35
+
36
+
37
+ if is_torch_available():
38
+ import torch
39
+
40
+ from transformers import (
41
+ AutoModelForCausalLM,
42
+ AutoTokenizer,
43
+ ClvpForCausalLM,
44
+ DynamicCache,
45
+ GenerationConfig,
46
+ LlamaConfig,
47
+ SinkCache,
48
+ StaticCache,
49
+ convert_and_export_with_cache,
50
+ )
51
+ from transformers.utils import is_torch_greater_or_equal
52
+
53
+
54
+ @require_torch
55
+ class CacheTest(unittest.TestCase):
56
+ def test_dynamic_cache_retrocompatibility(self):
57
+ """Tests that we can convert back and forth between the legacy cache format and DynamicCache"""
58
+ legacy_cache = ()
59
+ new_cache = DynamicCache()
60
+
61
+ # Creates a new cache with 10 layers in both formats
62
+ for layer_idx in range(10):
63
+ new_key = torch.rand((2, 4, 8, 16))
64
+ new_value = torch.rand((2, 4, 8, 16))
65
+ new_cache.update(new_key, new_value, layer_idx)
66
+ legacy_cache += ((new_key, new_value),)
67
+
68
+ # Sanity check 1: they must have the same shapes
69
+ self.assertTrue(len(legacy_cache), len(new_cache))
70
+ for layer_idx in range(10):
71
+ self.assertTrue(len(legacy_cache[layer_idx]), len(legacy_cache[layer_idx]))
72
+ for key_value_idx in range(2):
73
+ self.assertTrue(
74
+ legacy_cache[layer_idx][key_value_idx].shape == new_cache[layer_idx][key_value_idx].shape
75
+ )
76
+
77
+ # Sanity check 2: we can get the sequence length in multiple ways with DynamicCache, and they return the
78
+ # expected value
79
+ self.assertTrue(legacy_cache[0][0].shape[-2] == new_cache[0][0].shape[-2] == new_cache.get_seq_length() == 8)
80
+
81
+ # Sanity check 3: they must be equal, and both support indexing
82
+ for layer_idx in range(10):
83
+ for key_value_idx in range(2):
84
+ self.assertTrue(
85
+ torch.allclose(new_cache[layer_idx][key_value_idx], legacy_cache[layer_idx][key_value_idx])
86
+ )
87
+
88
+ # Test 1: We can convert from legacy to new with no changes
89
+ from_legacy = DynamicCache.from_legacy_cache(legacy_cache)
90
+ for layer_idx in range(10):
91
+ for key_value_idx in range(2):
92
+ self.assertTrue(
93
+ torch.allclose(from_legacy[layer_idx][key_value_idx], legacy_cache[layer_idx][key_value_idx])
94
+ )
95
+
96
+ # Test 2: We can convert from new to legacy with no changes
97
+ to_legacy = new_cache.to_legacy_cache()
98
+ for layer_idx in range(10):
99
+ for key_value_idx in range(2):
100
+ self.assertTrue(
101
+ torch.allclose(to_legacy[layer_idx][key_value_idx], new_cache[layer_idx][key_value_idx])
102
+ )
103
+
104
+ def test_reorder_cache_retrocompatibility(self):
105
+ """Tests that Cache.reorder_cache is retrocompatible with the legacy code path"""
106
+ legacy_reorder_fn = ClvpForCausalLM._reorder_cache # An example of a legacy `_reorder_cache` function
107
+
108
+ legacy_cache = ()
109
+ new_cache = DynamicCache()
110
+
111
+ # Creates a new cache with 10 layers in both formats
112
+ for layer_idx in range(10):
113
+ new_key = torch.rand((4, 4, 8, 16))
114
+ new_value = torch.rand((4, 4, 8, 16))
115
+ new_cache.update(new_key, new_value, layer_idx)
116
+ legacy_cache += ((new_key, new_value),)
117
+
118
+ # Let's create some dummy beam indices. From the shape above, it is equivalent to the case where num_beams=4
119
+ # and batch_size=1
120
+ beam_idx = torch.randint(low=0, high=4, size=(4,))
121
+
122
+ legacy_cache_reordered = legacy_reorder_fn(legacy_cache, beam_idx)
123
+ new_cache.reorder_cache(beam_idx)
124
+
125
+ # Let's check that the results are the same
126
+ for layer_idx in range(10):
127
+ for key_value_idx in range(2):
128
+ self.assertTrue(
129
+ torch.allclose(
130
+ new_cache[layer_idx][key_value_idx], legacy_cache_reordered[layer_idx][key_value_idx]
131
+ )
132
+ )
133
+
134
+ def test_static_cache_mha_mqa_gqa(self):
135
+ """
136
+ Tests that static cache works with multi-head attention (MHA), grouped query attention (GQA), and multi-query
137
+ attention (MQA)
138
+ """
139
+
140
+ def _random_kvs(config):
141
+ # shape for key and values: (batch_size, num_heads, seq_len, head_dim)
142
+ random_keys = torch.rand(
143
+ (1, config.num_key_value_heads, 1, config.hidden_size // config.num_attention_heads),
144
+ device=torch_device,
145
+ )
146
+ random_values = torch.rand(
147
+ (1, config.num_key_value_heads, 1, config.hidden_size // config.num_attention_heads),
148
+ device=torch_device,
149
+ )
150
+ return random_keys, random_values
151
+
152
+ mha_config = LlamaConfig(num_attention_heads=32)
153
+ mha_static_cache = StaticCache(config=mha_config, max_batch_size=1, max_cache_len=10, device=torch_device)
154
+ cached_keys, cached_values = mha_static_cache.update(
155
+ *_random_kvs(mha_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)}
156
+ )
157
+ self.assertTrue(cached_keys.shape == (1, 32, 10, 128))
158
+ self.assertTrue(cached_values.shape == (1, 32, 10, 128))
159
+
160
+ gqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=4)
161
+ gqa_static_cache = StaticCache(config=gqa_config, max_batch_size=1, max_cache_len=10, device=torch_device)
162
+ cached_keys, cached_values = gqa_static_cache.update(
163
+ *_random_kvs(gqa_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)}
164
+ )
165
+ self.assertTrue(cached_keys.shape == (1, 4, 10, 128))
166
+ self.assertTrue(cached_values.shape == (1, 4, 10, 128))
167
+
168
+ mqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=1)
169
+ mqa_static_cache = StaticCache(config=mqa_config, max_batch_size=1, max_cache_len=10, device=torch_device)
170
+ cached_keys, cached_values = mqa_static_cache.update(
171
+ *_random_kvs(mqa_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)}
172
+ )
173
+ self.assertTrue(cached_keys.shape == (1, 1, 10, 128))
174
+ self.assertTrue(cached_values.shape == (1, 1, 10, 128))
175
+
176
+ def test_dynamic_cache_exportability(self):
177
+ model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
178
+ model = model.eval()
179
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
180
+ prompt = "What is the best way to debug python script?"
181
+ inputs = tokenizer(prompt, return_tensors="pt")
182
+ attention_mask = inputs.attention_mask
183
+ input_ids = inputs.input_ids
184
+
185
+ past_key_values = DynamicCache()
186
+ ep = torch.export.export(
187
+ model,
188
+ (),
189
+ {
190
+ "input_ids": input_ids,
191
+ "attention_mask": attention_mask,
192
+ "past_key_values": past_key_values,
193
+ "use_cache": True,
194
+ },
195
+ strict=False,
196
+ )
197
+ res = ep.module()(
198
+ input_ids=input_ids,
199
+ attention_mask=attention_mask,
200
+ past_key_values=past_key_values,
201
+ use_cache=True,
202
+ )
203
+ self.assertTrue(len(res.past_key_values.key_cache) == model.config.num_hidden_layers)
204
+ self.assertEqual(2 * model.config.num_hidden_layers + 1, len(ep.graph_signature.output_specs))
205
+ self.assertEqual(
206
+ 3,
207
+ len(
208
+ [
209
+ x
210
+ for x in ep.graph_signature.input_specs
211
+ if x.kind == torch.export.graph_signature.InputKind.USER_INPUT
212
+ ]
213
+ ),
214
+ )
215
+
216
+ past_key_values_eager = DynamicCache()
217
+ res_eager = model(
218
+ input_ids=input_ids,
219
+ attention_mask=attention_mask,
220
+ past_key_values=past_key_values_eager,
221
+ use_cache=True,
222
+ )
223
+ self.assertTrue(torch.allclose(res.logits, res_eager.logits))
224
+ for k1, k2 in zip(res.past_key_values.key_cache, res_eager.past_key_values.key_cache):
225
+ self.assertTrue(torch.allclose(k1, k2))
226
+
227
+ for v1, v2 in zip(res.past_key_values.value_cache, res_eager.past_key_values.value_cache):
228
+ self.assertTrue(torch.allclose(v1, v2))
229
+
230
+ @slow
231
+ @require_read_token
232
+ def test_static_cache_exportability(self):
233
+ """
234
+ Tests that static cache works with `torch.export()`
235
+ """
236
+ if not is_torch_greater_or_equal("2.3"):
237
+ self.skipTest(reason="This test requires torch >= 2.3 to run.")
238
+
239
+ set_seed(0)
240
+ device = "cpu"
241
+ dtype = "bfloat16"
242
+ cache_implementation = "static"
243
+ attn_implementation = "sdpa" # Export and ExecuTorch only works for SdpaAttention
244
+ batch_size = 1
245
+ max_cache_len = 1234
246
+ model = AutoModelForCausalLM.from_pretrained(
247
+ "google/gemma-2b",
248
+ device_map=device,
249
+ torch_dtype=dtype,
250
+ attn_implementation=attn_implementation,
251
+ generation_config=GenerationConfig(
252
+ use_cache=True,
253
+ cache_implementation=cache_implementation,
254
+ max_length=max_cache_len,
255
+ cache_config={
256
+ "batch_size": batch_size,
257
+ "max_cache_len": max_cache_len,
258
+ "device": device,
259
+ },
260
+ ),
261
+ )
262
+ # Check if cache config is passed through correctly
263
+ self.assertEqual(model.generation_config.use_cache, True)
264
+ self.assertEqual(model.generation_config.cache_implementation, cache_implementation)
265
+ self.assertEqual(model.generation_config.max_length, max_cache_len)
266
+ self.assertTrue(model.generation_config.cache_config is not None)
267
+ self.assertEqual(model.generation_config.cache_config.batch_size, batch_size)
268
+ self.assertEqual(model.generation_config.cache_config.max_cache_len, max_cache_len)
269
+
270
+ exported_program = convert_and_export_with_cache(model)
271
+
272
+ # Check if the exported model is configured with the `StaticCache` correctly
273
+ n_static_key_caches = n_static_value_caches = 0
274
+ for buffer_name, buffer in exported_program.named_buffers():
275
+ if buffer_name.startswith("key_cache"):
276
+ self.assertTrue(buffer.shape[0] == batch_size)
277
+ self.assertTrue(buffer.shape[2] == max_cache_len)
278
+ n_static_key_caches = n_static_key_caches + 1
279
+ if buffer_name.startswith("value_cache"):
280
+ self.assertTrue(buffer.shape[0] == batch_size)
281
+ self.assertTrue(buffer.shape[2] == max_cache_len)
282
+ n_static_value_caches = n_static_value_caches + 1
283
+ self.assertEqual(n_static_key_caches, model.config.num_hidden_layers)
284
+ self.assertEqual(n_static_value_caches, model.config.num_hidden_layers)
285
+
286
+
287
+ @require_torch_accelerator
288
+ @slow
289
+ class CacheIntegrationTest(unittest.TestCase):
290
+ def test_dynamic_cache_hard(self):
291
+ tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", padding_side="left")
292
+ model = AutoModelForCausalLM.from_pretrained(
293
+ "meta-llama/Llama-2-7b-hf", device_map="auto", torch_dtype=torch.float16
294
+ )
295
+ inputs = tokenizer(["Here's everything I know about cats. Cats"], return_tensors="pt").to(model.device)
296
+
297
+ # DynamicCache and the legacy cache format should be equivalent
298
+ set_seed(0)
299
+ gen_out_legacy = model.generate(**inputs, do_sample=True, max_new_tokens=256)
300
+ set_seed(0)
301
+ gen_out = model.generate(**inputs, do_sample=True, max_new_tokens=256, past_key_values=DynamicCache())
302
+ self.assertListEqual(gen_out_legacy.tolist(), gen_out.tolist())
303
+
304
+ decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
305
+ expected_text = (
306
+ "Here's everything I know about cats. Cats are mysterious creatures. They can't talk, and they don't like "
307
+ "to be held. They don't play fetch, and they don't like to be hugged. But they do like to be petted.\n"
308
+ "Cats are also very independent. They don't like to be told what to do, and they don't like to be told "
309
+ "what to eat. They are also very territorial. They don't like to share their food or their toys.\nCats "
310
+ "are also very curious. They like to explore, and they like to play. They are also very fast. They can "
311
+ "run very fast, and they can jump very high.\nCats are also very smart. They can learn tricks, and they "
312
+ "can solve problems. They are also very playful. They like to play with toys, and they like to play with "
313
+ "other cats.\nCats are also very affectionate. They like to be petted, and they like to be held. They "
314
+ "also like to be scratched.\nCats are also very clean. They like to groom themselves, and they like to "
315
+ "clean their litter box.\nCats are also very independent. They don't"
316
+ )
317
+ self.assertEqual(decoded[0], expected_text)
318
+
319
+ def test_dynamic_cache_batched(self):
320
+ tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", padding_side="left")
321
+ tokenizer.pad_token = tokenizer.eos_token
322
+ model = AutoModelForCausalLM.from_pretrained(
323
+ "meta-llama/Llama-2-7b-hf", device_map="auto", torch_dtype=torch.float16
324
+ )
325
+ inputs = tokenizer(["A sequence: 1, 2, 3, 4, 5", "A sequence: A, B, C"], padding=True, return_tensors="pt").to(
326
+ model.device
327
+ )
328
+
329
+ gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10, past_key_values=DynamicCache())
330
+ decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
331
+ expected_text = ["A sequence: 1, 2, 3, 4, 5, 6, 7, 8,", "A sequence: A, B, C, D, E, F, G, H"]
332
+ self.assertListEqual(decoded, expected_text)
333
+
334
+ def test_dynamic_cache_beam_search(self):
335
+ tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", padding_side="left")
336
+ model = AutoModelForCausalLM.from_pretrained(
337
+ "meta-llama/Llama-2-7b-hf", device_map="auto", torch_dtype=torch.float16
338
+ )
339
+
340
+ inputs = tokenizer(["The best color is"], return_tensors="pt").to(model.device)
341
+ gen_out = model.generate(
342
+ **inputs,
343
+ do_sample=False,
344
+ max_new_tokens=20,
345
+ num_beams=2,
346
+ num_return_sequences=2,
347
+ )
348
+ decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
349
+ expected_text = [
350
+ "The best color is the one that makes you feel good.\nThe best color is the one that makes you feel good",
351
+ "The best color is the one that suits you.\nThe best color is the one that suits you. The",
352
+ ]
353
+ self.assertListEqual(decoded, expected_text)
354
+
355
+ def test_hybrid_cache_n_sequences(self):
356
+ tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
357
+ model = AutoModelForCausalLM.from_pretrained(
358
+ "google/gemma-2-9b",
359
+ device_map="auto",
360
+ torch_dtype=torch.bfloat16,
361
+ attn_implementation="eager",
362
+ )
363
+
364
+ inputs = tokenizer(["Hello I am doing"], return_tensors="pt").to(model.device)
365
+
366
+ gen_out = model.generate(
367
+ **inputs,
368
+ do_sample=False,
369
+ max_new_tokens=20,
370
+ num_return_sequences=2,
371
+ num_beams=2,
372
+ )
373
+ decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
374
+ expected_text = [
375
+ "Hello I am doing a project for my school and I am trying to make a program that will allow me to input a",
376
+ "Hello I am doing a project for my school and I am trying to make a program that will allow me to use a",
377
+ ]
378
+ self.assertListEqual(decoded, expected_text)
379
+
380
+ @require_non_xpu
381
+ @require_gptq
382
+ def test_sink_cache_hard(self):
383
+ tokenizer = AutoTokenizer.from_pretrained("TheBloke/LLaMa-7B-GPTQ")
384
+ model = AutoModelForCausalLM.from_pretrained("TheBloke/LLaMa-7B-GPTQ", device_map="auto")
385
+
386
+ inputs = tokenizer(["Vaswani et al. (2017) introduced the Transformers"], return_tensors="pt").to(model.device)
387
+
388
+ # Set up the SinkCache. Using a small window length to contain computational complexity. If this example is run
389
+ # without a SinkCache, the last few tokens are gibberish (ends in "of the of the of a of a of")
390
+ cache = SinkCache(window_length=508, num_sink_tokens=4)
391
+ gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=3000, past_key_values=cache)
392
+ decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
393
+ self.assertTrue(decoded[0].endswith("to perform a variety of tasks. The Transformer is a neural network"))
394
+
395
+ def test_sink_cache_iterative_prompts(self):
396
+ """Tests that SinkCache supports more than one new token at once, when shifting the cache"""
397
+ tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
398
+ model = AutoModelForCausalLM.from_pretrained(
399
+ "HuggingFaceH4/zephyr-7b-beta", device_map="auto", torch_dtype=torch.float16
400
+ )
401
+ prompt = (
402
+ "Compose an engaging travel blog post about a recent trip to Hawaii, highlighting cultural experiences "
403
+ "and must-see attractions."
404
+ )
405
+
406
+ # Prepare generation settings
407
+ cache = SinkCache(window_length=256, num_sink_tokens=4)
408
+ input_ids = torch.tensor([], device=model.device, dtype=torch.int)
409
+ for _ in range(3):
410
+ # Tokenize the prompt with the correct chat template
411
+ chat = [{"role": "user", "content": prompt}]
412
+ tokenized_chat = tokenizer.apply_chat_template(chat, return_tensors="pt", add_generation_prompt=True).to(
413
+ model.device
414
+ )
415
+ input_ids = torch.cat((input_ids, tokenized_chat), dim=1)
416
+
417
+ # Perform the generation
418
+ gen_out = model.generate(
419
+ input_ids, do_sample=False, max_new_tokens=100, past_key_values=cache, use_cache=True
420
+ )
421
+ input_ids = gen_out
422
+
423
+ # We went well beyond the cache length
424
+ self.assertTrue(input_ids.shape[1] > cache.get_max_cache_shape() * 1.5)
425
+
426
+ # And it still produces a coherent english
427
+ decoded = tokenizer.batch_decode(input_ids, skip_special_tokens=True)
428
+ last_output = (
429
+ "<|assistant|>\nAs the sun began to set over the Pacific Ocean, I found myself standing on the shores of "
430
+ "Waikiki Beach, my heart filled with awe and wonder. I had just returned from a two-week journey to the "
431
+ "beautiful island of Hawaii, and it had been an unforgettable experience filled with cultural experiences "
432
+ "and must-see attractions that left me breathless.\n\nOne of the most memorable experiences of my trip "
433
+ "was visiting the historic district of Honolulu. Here,"
434
+ )
435
+ self.assertTrue(decoded[0].endswith(last_output))
436
+
437
+ @require_torch_gpu
438
+ @parameterized.expand(
439
+ [
440
+ ("eager", "static"),
441
+ ("sdpa", "static"),
442
+ ]
443
+ )
444
+ def test_static_cache_greedy_decoding_pad_left(self, attn_implementation, cache_implementation):
445
+ EXPECTED_GENERATION = [
446
+ "The best color is the one that complements the skin tone of the",
447
+ "We should not undermind the issues at hand.\nWe should not undermind the issues",
448
+ ]
449
+
450
+ tokenizer = AutoTokenizer.from_pretrained(
451
+ "NousResearch/Llama-2-7b-chat-hf", padding_side="left", pad_token="<s>"
452
+ )
453
+ model = AutoModelForCausalLM.from_pretrained(
454
+ "NousResearch/Llama-2-7b-chat-hf",
455
+ torch_dtype=torch.bfloat16,
456
+ attn_implementation=attn_implementation,
457
+ ).to(torch_device)
458
+ inputs = tokenizer(
459
+ ["The best color is", "We should not undermind the issues at hand"], padding=True, return_tensors="pt"
460
+ ).to(model.device)
461
+
462
+ set_seed(0)
463
+ gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
464
+ decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
465
+ with self.subTest(f"{attn_implementation}, dynamic"):
466
+ self.assertListEqual(decoded, EXPECTED_GENERATION)
467
+
468
+ set_seed(0)
469
+ model.generation_config.cache_implementation = cache_implementation
470
+ gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
471
+ decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
472
+ with self.subTest(f"{attn_implementation}, static, eager"):
473
+ self.assertListEqual(decoded, EXPECTED_GENERATION)
474
+
475
+ set_seed(0)
476
+ model.forward = torch.compile(model.forward)
477
+ gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
478
+ decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
479
+ with self.subTest(f"{attn_implementation}, static, compiled"):
480
+ self.assertListEqual(decoded, EXPECTED_GENERATION)
481
+
482
+ @require_torch_gpu
483
+ @parameterized.expand(
484
+ [
485
+ ("eager", "static"),
486
+ ("sdpa", "static"),
487
+ ]
488
+ )
489
+ def test_static_cache_greedy_decoding_pad_right(self, attn_implementation, cache_implementation):
490
+ EXPECTED_GENERATION = [
491
+ "The best color isЋ the one that complements the skin tone of",
492
+ "We should not undermind the issues at hand.\nWe should not undermind the issues",
493
+ ]
494
+
495
+ tokenizer = AutoTokenizer.from_pretrained(
496
+ "NousResearch/Llama-2-7b-chat-hf", padding_side="right", pad_token="<s>"
497
+ )
498
+ model = AutoModelForCausalLM.from_pretrained(
499
+ "NousResearch/Llama-2-7b-chat-hf",
500
+ torch_dtype=torch.bfloat16,
501
+ attn_implementation=attn_implementation,
502
+ ).to(torch_device)
503
+ inputs = tokenizer(
504
+ ["The best color is", "We should not undermind the issues at hand"], padding=True, return_tensors="pt"
505
+ ).to(model.device)
506
+
507
+ set_seed(0)
508
+ gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
509
+ decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
510
+ with self.subTest(f"{attn_implementation}, dynamic"):
511
+ self.assertListEqual(decoded, EXPECTED_GENERATION)
512
+
513
+ set_seed(0)
514
+ model.generation_config.cache_implementation = cache_implementation
515
+ gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
516
+ decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
517
+ with self.subTest(f"{attn_implementation}, static, eager"):
518
+ self.assertListEqual(decoded, EXPECTED_GENERATION)
519
+
520
+ def test_dynamic_cache_extra_left_padding(self):
521
+ """Tests that adding extra left-padding does not affect the generation with the dynamic cache"""
522
+ EXPECTED_GENERATION = [
523
+ "The best color is the one that complements the skin tone of the",
524
+ "We should not undermind the issues at hand.\nWe should not undermind the issues",
525
+ ]
526
+
527
+ tokenizer = AutoTokenizer.from_pretrained(
528
+ "NousResearch/Llama-2-7b-chat-hf", padding_side="left", pad_token="<s>"
529
+ )
530
+ model = AutoModelForCausalLM.from_pretrained(
531
+ "NousResearch/Llama-2-7b-chat-hf",
532
+ torch_dtype=torch.bfloat16,
533
+ ).to(torch_device)
534
+ inputs = tokenizer(
535
+ ["The best color is", "We should not undermind the issues at hand"], padding=True, return_tensors="pt"
536
+ ).to(model.device)
537
+
538
+ gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
539
+ decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
540
+ self.assertListEqual(decoded, EXPECTED_GENERATION)
541
+
542
+ # Now with extra left-padding
543
+ inputs_expanded = tokenizer(
544
+ ["The best color is", "We should not undermind the issues at hand"],
545
+ padding=True,
546
+ return_tensors="pt",
547
+ pad_to_multiple_of=32,
548
+ ).to(model.device)
549
+ self.assertTrue(inputs.input_ids.shape[1] < inputs_expanded.input_ids.shape[1])
550
+ gen_out = model.generate(**inputs_expanded, do_sample=False, max_new_tokens=10)
551
+ decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
552
+ self.assertListEqual(decoded, EXPECTED_GENERATION)
553
+
554
+ @parameterized.expand(
555
+ [
556
+ "static",
557
+ ]
558
+ )
559
+ def test_static_cache_extra_left_padding(self, cache_implementation):
560
+ """Tests that adding extra left-padding does not affect the generation with the static cache"""
561
+ EXPECTED_GENERATION = [
562
+ "The best color is the one that complements the skin tone of the",
563
+ "We should not undermind the issues at hand.\nWe should not undermind the issues",
564
+ ]
565
+
566
+ tokenizer = AutoTokenizer.from_pretrained(
567
+ "NousResearch/Llama-2-7b-chat-hf", padding_side="left", pad_token="<s>"
568
+ )
569
+ model = AutoModelForCausalLM.from_pretrained(
570
+ "NousResearch/Llama-2-7b-chat-hf",
571
+ torch_dtype=torch.bfloat16,
572
+ ).to(torch_device)
573
+ inputs = tokenizer(
574
+ ["The best color is", "We should not undermind the issues at hand"], padding=True, return_tensors="pt"
575
+ ).to(model.device)
576
+
577
+ model.generation_config.cache_implementation = cache_implementation
578
+
579
+ gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
580
+ decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
581
+ self.assertListEqual(decoded, EXPECTED_GENERATION)
582
+
583
+ # Now with extra left-padding
584
+ inputs_expanded = tokenizer(
585
+ ["The best color is", "We should not undermind the issues at hand"],
586
+ padding=True,
587
+ return_tensors="pt",
588
+ pad_to_multiple_of=32,
589
+ ).to(model.device)
590
+ self.assertTrue(inputs.input_ids.shape[1] < inputs_expanded.input_ids.shape[1])
591
+ gen_out = model.generate(**inputs_expanded, do_sample=False, max_new_tokens=10)
592
+ decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
593
+ self.assertListEqual(decoded, EXPECTED_GENERATION)
594
+
595
+ @unittest.skip(reason="TODO @gante static cache's does not support beam search yet")
596
+ def test_static_cache_beam_search(self):
597
+ pass
598
+
599
+ @require_torch_accelerator
600
+ def test_offloaded_cache_equivalent_to_dynamic_cache(self):
601
+ """Tests that OffloadedCache produces the same result as the default DynamicCache"""
602
+ model_name = "microsoft/Phi-3-mini-4k-instruct"
603
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
604
+ model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
605
+ device = model.device
606
+
607
+ if not is_torch_greater_or_equal("2.7", accept_dev=True) and device.type == "xpu":
608
+ self.skipTest(reason="This test requires torch >= 2.7 to run on xpu.")
609
+
610
+ input_text = "Fun fact:"
611
+ inputs = tokenizer(input_text, return_tensors="pt").to(device)
612
+ common = {
613
+ "num_beams": 4,
614
+ "num_beam_groups": 2,
615
+ "num_return_sequences": 4,
616
+ "diversity_penalty": 1.0,
617
+ "max_new_tokens": 20,
618
+ "early_stopping": True,
619
+ }
620
+ original = GenerationConfig(**common)
621
+ offloaded = GenerationConfig(cache_implementation="offloaded", **common)
622
+ original_outputs = model.generate(generation_config=original, **inputs)
623
+ offloaded_outputs = model.generate(generation_config=offloaded, **inputs)
624
+ for original_output, offloaded_output in zip(original_outputs, offloaded_outputs):
625
+ assert torch.all(original_output == offloaded_output).item()
626
+
627
+ @require_torch_accelerator
628
+ def test_offloaded_cache_uses_less_memory_than_dynamic_cache(self):
629
+ """Tests that OffloadedCache uses less memory than the default DynamicCache"""
630
+ model_name = "microsoft/Phi-3-mini-4k-instruct"
631
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
632
+ model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
633
+ device = model.device
634
+
635
+ if not is_torch_greater_or_equal("2.7", accept_dev=True) and device.type == "xpu":
636
+ self.skipTest(reason="This test requires torch >= 2.7 to run on xpu.")
637
+
638
+ input_text = "Fun fact:"
639
+ inputs = tokenizer(input_text, return_tensors="pt").to(device)
640
+ common = {
641
+ "num_beams": 4,
642
+ "num_beam_groups": 2,
643
+ "num_return_sequences": 4,
644
+ "diversity_penalty": 1.0,
645
+ "max_new_tokens": 20,
646
+ "early_stopping": True,
647
+ }
648
+ original = GenerationConfig(**common)
649
+ offloaded = GenerationConfig(cache_implementation="offloaded", **common)
650
+
651
+ torch_accelerator_module = None
652
+ if device.type == "cuda":
653
+ torch_accelerator_module = torch.cuda
654
+ elif device.type == "xpu":
655
+ torch_accelerator_module = torch.xpu
656
+
657
+ torch_accelerator_module.reset_peak_memory_stats(device)
658
+ model.generate(generation_config=original, **inputs)
659
+ original_peak_memory = torch_accelerator_module.max_memory_allocated(device)
660
+ torch_accelerator_module.reset_peak_memory_stats(device)
661
+ model.generate(generation_config=offloaded, **inputs)
662
+ offloaded_peak_memory = torch_accelerator_module.max_memory_allocated(device)
663
+ print(f"original_peak_memory: {original_peak_memory}, offloaded_peak_memory: {offloaded_peak_memory}")
664
+ assert offloaded_peak_memory < original_peak_memory
665
+
666
+ @require_torch_gpu
667
+ def test_cache_copy(self):
668
+ model_name = "microsoft/Phi-3-mini-4k-instruct"
669
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
670
+ model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cuda", torch_dtype=torch.bfloat16)
671
+
672
+ prompt_cache = StaticCache(
673
+ config=model.config, max_batch_size=1, max_cache_len=1024, device="cuda", dtype=torch.bfloat16
674
+ )
675
+
676
+ INITIAL_PROMPT = "You are a helpful assistant. "
677
+ inputs_initial_prompt = tokenizer(INITIAL_PROMPT, return_tensors="pt").to("cuda")
678
+ # This is the common prompt cached, we need to run forward without grad to be abel to copy
679
+ with torch.no_grad():
680
+ prompt_cache = model(**inputs_initial_prompt, past_key_values=prompt_cache).past_key_values
681
+
682
+ prompts = ["Help me to write a blogpost about travelling.", "What is the capital of France?"]
683
+ responses = []
684
+ for prompt in prompts:
685
+ new_inputs = tokenizer(INITIAL_PROMPT + prompt, return_tensors="pt").to("cuda")
686
+ past_key_values = copy.deepcopy(prompt_cache)
687
+ outputs = model.generate(**new_inputs, past_key_values=past_key_values, max_new_tokens=40)
688
+ response = tokenizer.batch_decode(outputs)[0]
689
+ responses.append(response)
690
+
691
+ EXPECTED_DECODED_TEXT = [
692
+ "You are a helpful assistant. Help me to write a blogpost about travelling.\n\nTraveling is an enriching experience that broadens our horizons and exposes us to new cultures, landscapes, and people. Whether it's a week",
693
+ 'You are a helpful assistant. What is the capital of France?\n\n\n## Response:Paris is the capital of France.\n\n\n\n\n\n## Query:\n\nIn a detailed analysis, compare the economic impacts of the introduction of the'
694
+ ] # fmt: skip
695
+ self.assertEqual(responses, EXPECTED_DECODED_TEXT)
696
+
697
+ @require_torch_multi_gpu
698
+ def test_data_parallel_dynamic_cache(self):
699
+ """
700
+ Tests that the dynamic cache works with nn.DataParallel. Under the hood, `DynamicCache` is rebuilt from
701
+ multiple `DynamicCache` in the gather step.
702
+ """
703
+
704
+ model_repo = "hf-internal-testing/tiny-random-MistralForCausalLM"
705
+ model = AutoModelForCausalLM.from_pretrained(model_repo).to(torch_device)
706
+ tokenizer = AutoTokenizer.from_pretrained(model_repo)
707
+
708
+ # w/o DP: batch_size = num_gpu
709
+ # w DP: batch_size = 1 (with num_gpus replicas)
710
+ num_gpus = get_gpu_count()
711
+ model_inputs = tokenizer(["foo bar"] * num_gpus, return_tensors="pt").to(model.device)
712
+
713
+ # w/o DP
714
+ no_parallelism_cache = model(**model_inputs).past_key_values
715
+ self.assertIsInstance(no_parallelism_cache, DynamicCache)
716
+
717
+ # w DP
718
+ model = torch.nn.DataParallel(model)
719
+ parallelism_cache = model(**model_inputs).past_key_values
720
+ self.assertIsInstance(parallelism_cache, DynamicCache)
721
+
722
+ # Check that the caches are the same
723
+ for layer_idx in range(len(no_parallelism_cache)):
724
+ for kv_idx in range(2): # 0 = key, 1 = value
725
+ torch.testing.assert_close(
726
+ actual=parallelism_cache[layer_idx][kv_idx], expected=no_parallelism_cache[layer_idx][kv_idx]
727
+ )
728
+
729
+ @require_torch_gpu
730
+ def test_static_cache_no_cuda_graph_skips(self):
731
+ """
732
+ Tests generating with static cache and compilation doesn't skip cuda graphs. Regression test for #36543.
733
+
734
+ (? We set `fullgraph=True`, which according to torch docs means it should raise an exception. Instead,
735
+ messages are being thrown to stderr?)
736
+ """
737
+ model_repo = "hf-internal-testing/tiny-random-MistralForCausalLM"
738
+ model = AutoModelForCausalLM.from_pretrained(model_repo).to(torch_device)
739
+ tokenizer = AutoTokenizer.from_pretrained(model_repo)
740
+ inputs = tokenizer(["foo bar"], return_tensors="pt").to(torch_device)
741
+
742
+ # on `main`, prior to #36543, this would send stderr messages about cuda graphs being skipped.
743
+ with CaptureStderr() as cap:
744
+ model.generate(**inputs, max_new_tokens=2, cache_implementation="static")
745
+ self.assertEqual(cap.err, "")
746
+
747
+ @require_torch_multi_gpu
748
+ def test_static_cache_multi_gpu(self):
749
+ """Regression test for #35164: static cache with multi-gpu"""
750
+
751
+ model_id = "google/gemma-2-2b-it"
752
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
753
+
754
+ device_map = {"model.embed_tokens": 0, "model.norm": 1, "model.rotary_emb": 1, "lm_head": 0}
755
+ num_hidden_layers = 26
756
+ for i in range(num_hidden_layers):
757
+ device_map[f"model.layers.{i}"] = 0 if i < 13 else 1
758
+
759
+ model = AutoModelForCausalLM.from_pretrained(
760
+ model_id,
761
+ torch_dtype="bfloat16",
762
+ device_map=device_map,
763
+ )
764
+ inputs = tokenizer("Today is a beautiful day!", return_tensors="pt").to(0)
765
+ _ = model(**inputs)
766
+ _ = model.generate(**inputs, max_new_tokens=2, cache_implementation="hybrid")
docs/transformers/tests/utils/test_chat_template_utils.py ADDED
@@ -0,0 +1,501 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import unittest
16
+ from typing import Optional, Union
17
+
18
+ from transformers.utils import DocstringParsingException, TypeHintParsingException, get_json_schema
19
+
20
+
21
+ class JsonSchemaGeneratorTest(unittest.TestCase):
22
+ def test_simple_function(self):
23
+ def fn(x: int):
24
+ """
25
+ Test function
26
+
27
+ Args:
28
+ x: The input
29
+ """
30
+ return x
31
+
32
+ schema = get_json_schema(fn)
33
+ expected_schema = {
34
+ "name": "fn",
35
+ "description": "Test function",
36
+ "parameters": {
37
+ "type": "object",
38
+ "properties": {"x": {"type": "integer", "description": "The input"}},
39
+ "required": ["x"],
40
+ },
41
+ }
42
+ self.assertEqual(schema["function"], expected_schema)
43
+
44
+ def test_no_arguments(self):
45
+ def fn():
46
+ """
47
+ Test function
48
+ """
49
+ return True
50
+
51
+ schema = get_json_schema(fn)
52
+ expected_schema = {
53
+ "name": "fn",
54
+ "description": "Test function",
55
+ "parameters": {"type": "object", "properties": {}},
56
+ }
57
+ self.assertEqual(schema["function"], expected_schema)
58
+
59
+ def test_union(self):
60
+ def fn(x: Union[int, float]):
61
+ """
62
+ Test function
63
+
64
+ Args:
65
+ x: The input
66
+ """
67
+ return x
68
+
69
+ schema = get_json_schema(fn)
70
+ expected_schema = {
71
+ "name": "fn",
72
+ "description": "Test function",
73
+ "parameters": {
74
+ "type": "object",
75
+ "properties": {"x": {"type": ["integer", "number"], "description": "The input"}},
76
+ "required": ["x"],
77
+ },
78
+ }
79
+ self.assertEqual(schema["function"], expected_schema)
80
+
81
+ def test_optional(self):
82
+ def fn(x: Optional[int]):
83
+ """
84
+ Test function
85
+
86
+ Args:
87
+ x: The input
88
+ """
89
+ return x
90
+
91
+ schema = get_json_schema(fn)
92
+ expected_schema = {
93
+ "name": "fn",
94
+ "description": "Test function",
95
+ "parameters": {
96
+ "type": "object",
97
+ "properties": {"x": {"type": "integer", "description": "The input", "nullable": True}},
98
+ "required": ["x"],
99
+ },
100
+ }
101
+ self.assertEqual(schema["function"], expected_schema)
102
+
103
+ def test_default_arg(self):
104
+ def fn(x: int = 42):
105
+ """
106
+ Test function
107
+
108
+ Args:
109
+ x: The input
110
+ """
111
+ return x
112
+
113
+ schema = get_json_schema(fn)
114
+ expected_schema = {
115
+ "name": "fn",
116
+ "description": "Test function",
117
+ "parameters": {"type": "object", "properties": {"x": {"type": "integer", "description": "The input"}}},
118
+ }
119
+ self.assertEqual(schema["function"], expected_schema)
120
+
121
+ def test_nested_list(self):
122
+ def fn(x: list[list[Union[str, int]]]):
123
+ """
124
+ Test function
125
+
126
+ Args:
127
+ x: The input
128
+ """
129
+ return x
130
+
131
+ schema = get_json_schema(fn)
132
+ expected_schema = {
133
+ "name": "fn",
134
+ "description": "Test function",
135
+ "parameters": {
136
+ "type": "object",
137
+ "properties": {
138
+ "x": {
139
+ "type": "array",
140
+ "items": {"type": "array", "items": {"type": ["integer", "string"]}},
141
+ "description": "The input",
142
+ }
143
+ },
144
+ "required": ["x"],
145
+ },
146
+ }
147
+ self.assertEqual(schema["function"], expected_schema)
148
+
149
+ def test_multiple_arguments(self):
150
+ def fn(x: int, y: str):
151
+ """
152
+ Test function
153
+
154
+ Args:
155
+ x: The input
156
+ y: Also the input
157
+ """
158
+ return x
159
+
160
+ schema = get_json_schema(fn)
161
+ expected_schema = {
162
+ "name": "fn",
163
+ "description": "Test function",
164
+ "parameters": {
165
+ "type": "object",
166
+ "properties": {
167
+ "x": {"type": "integer", "description": "The input"},
168
+ "y": {"type": "string", "description": "Also the input"},
169
+ },
170
+ "required": ["x", "y"],
171
+ },
172
+ }
173
+ self.assertEqual(schema["function"], expected_schema)
174
+
175
+ def test_multiple_complex_arguments(self):
176
+ def fn(x: list[Union[int, float]], y: Optional[Union[int, str]] = None):
177
+ """
178
+ Test function
179
+
180
+ Args:
181
+ x: The input
182
+ y: Also the input
183
+ """
184
+ return x
185
+
186
+ schema = get_json_schema(fn)
187
+ expected_schema = {
188
+ "name": "fn",
189
+ "description": "Test function",
190
+ "parameters": {
191
+ "type": "object",
192
+ "properties": {
193
+ "x": {"type": "array", "items": {"type": ["integer", "number"]}, "description": "The input"},
194
+ "y": {
195
+ "type": ["integer", "string"],
196
+ "nullable": True,
197
+ "description": "Also the input",
198
+ },
199
+ },
200
+ "required": ["x"],
201
+ },
202
+ }
203
+ self.assertEqual(schema["function"], expected_schema)
204
+
205
+ def test_missing_docstring(self):
206
+ def fn(x: int):
207
+ return x
208
+
209
+ with self.assertRaises(DocstringParsingException):
210
+ get_json_schema(fn)
211
+
212
+ def test_missing_param_docstring(self):
213
+ def fn(x: int):
214
+ """
215
+ Test function
216
+ """
217
+ return x
218
+
219
+ with self.assertRaises(DocstringParsingException):
220
+ get_json_schema(fn)
221
+
222
+ def test_missing_type_hint(self):
223
+ def fn(x):
224
+ """
225
+ Test function
226
+
227
+ Args:
228
+ x: The input
229
+ """
230
+ return x
231
+
232
+ with self.assertRaises(TypeHintParsingException):
233
+ get_json_schema(fn)
234
+
235
+ def test_return_value(self):
236
+ def fn(x: int) -> int:
237
+ """
238
+ Test function
239
+
240
+ Args:
241
+ x: The input
242
+ """
243
+ return x
244
+
245
+ schema = get_json_schema(fn)
246
+ expected_schema = {
247
+ "name": "fn",
248
+ "description": "Test function",
249
+ "parameters": {
250
+ "type": "object",
251
+ "properties": {"x": {"type": "integer", "description": "The input"}},
252
+ "required": ["x"],
253
+ },
254
+ "return": {"type": "integer"},
255
+ }
256
+ self.assertEqual(schema["function"], expected_schema)
257
+
258
+ def test_return_value_docstring(self):
259
+ def fn(x: int) -> int:
260
+ """
261
+ Test function
262
+
263
+ Args:
264
+ x: The input
265
+
266
+
267
+ Returns:
268
+ The output
269
+ """
270
+ return x
271
+
272
+ schema = get_json_schema(fn)
273
+ expected_schema = {
274
+ "name": "fn",
275
+ "description": "Test function",
276
+ "parameters": {
277
+ "type": "object",
278
+ "properties": {"x": {"type": "integer", "description": "The input"}},
279
+ "required": ["x"],
280
+ },
281
+ "return": {"type": "integer", "description": "The output"},
282
+ }
283
+ self.assertEqual(schema["function"], expected_schema)
284
+
285
+ def test_tuple(self):
286
+ def fn(x: tuple[int, str]):
287
+ """
288
+ Test function
289
+
290
+ Args:
291
+ x: The input
292
+
293
+
294
+ Returns:
295
+ The output
296
+ """
297
+ return x
298
+
299
+ schema = get_json_schema(fn)
300
+ expected_schema = {
301
+ "name": "fn",
302
+ "description": "Test function",
303
+ "parameters": {
304
+ "type": "object",
305
+ "properties": {
306
+ "x": {
307
+ "type": "array",
308
+ "prefixItems": [{"type": "integer"}, {"type": "string"}],
309
+ "description": "The input",
310
+ }
311
+ },
312
+ "required": ["x"],
313
+ },
314
+ }
315
+ self.assertEqual(schema["function"], expected_schema)
316
+
317
+ def test_single_element_tuple_fails(self):
318
+ def fn(x: tuple[int]):
319
+ """
320
+ Test function
321
+
322
+ Args:
323
+ x: The input
324
+
325
+
326
+ Returns:
327
+ The output
328
+ """
329
+ return x
330
+
331
+ # Single-element tuples should just be the type itself, or List[type] for variable-length inputs
332
+ with self.assertRaises(TypeHintParsingException):
333
+ get_json_schema(fn)
334
+
335
+ def test_ellipsis_type_fails(self):
336
+ def fn(x: tuple[int, ...]):
337
+ """
338
+ Test function
339
+
340
+ Args:
341
+ x: The input
342
+
343
+
344
+ Returns:
345
+ The output
346
+ """
347
+ return x
348
+
349
+ # Variable length inputs should be specified with List[type], not Tuple[type, ...]
350
+ with self.assertRaises(TypeHintParsingException):
351
+ get_json_schema(fn)
352
+
353
+ def test_enum_extraction(self):
354
+ def fn(temperature_format: str):
355
+ """
356
+ Test function
357
+
358
+ Args:
359
+ temperature_format: The temperature format to use (Choices: ["celsius", "fahrenheit"])
360
+
361
+
362
+ Returns:
363
+ The temperature
364
+ """
365
+ return -40.0
366
+
367
+ # Let's see if that gets correctly parsed as an enum
368
+ schema = get_json_schema(fn)
369
+ expected_schema = {
370
+ "name": "fn",
371
+ "description": "Test function",
372
+ "parameters": {
373
+ "type": "object",
374
+ "properties": {
375
+ "temperature_format": {
376
+ "type": "string",
377
+ "enum": ["celsius", "fahrenheit"],
378
+ "description": "The temperature format to use",
379
+ }
380
+ },
381
+ "required": ["temperature_format"],
382
+ },
383
+ }
384
+
385
+ self.assertEqual(schema["function"], expected_schema)
386
+
387
+ def test_multiline_docstring_with_types(self):
388
+ def fn(x: int, y: int):
389
+ """
390
+ Test function
391
+
392
+ Args:
393
+ x: The first input
394
+
395
+ y: The second input. This is a longer description
396
+ that spans multiple lines with indentation and stuff.
397
+
398
+ Returns:
399
+ God knows what
400
+ """
401
+ pass
402
+
403
+ schema = get_json_schema(fn)
404
+ expected_schema = {
405
+ "name": "fn",
406
+ "description": "Test function",
407
+ "parameters": {
408
+ "type": "object",
409
+ "properties": {
410
+ "x": {"type": "integer", "description": "The first input"},
411
+ "y": {
412
+ "type": "integer",
413
+ "description": "The second input. This is a longer description that spans multiple lines with indentation and stuff.",
414
+ },
415
+ },
416
+ "required": ["x", "y"],
417
+ },
418
+ }
419
+
420
+ self.assertEqual(schema["function"], expected_schema)
421
+
422
+ def test_return_none(self):
423
+ def fn(x: int) -> None:
424
+ """
425
+ Test function
426
+
427
+ Args:
428
+ x: The first input
429
+ """
430
+ pass
431
+
432
+ schema = get_json_schema(fn)
433
+ expected_schema = {
434
+ "name": "fn",
435
+ "description": "Test function",
436
+ "parameters": {
437
+ "type": "object",
438
+ "properties": {
439
+ "x": {"type": "integer", "description": "The first input"},
440
+ },
441
+ "required": ["x"],
442
+ },
443
+ "return": {"type": "null"},
444
+ }
445
+ self.assertEqual(schema["function"], expected_schema)
446
+
447
+ def test_everything_all_at_once(self):
448
+ def fn(
449
+ x: str, y: Optional[list[Union[str, int]]], z: tuple[Union[str, int], str] = (42, "hello")
450
+ ) -> tuple[int, str]:
451
+ """
452
+ Test function with multiple args, and docstring args that we have to strip out.
453
+
454
+ Args:
455
+ x: The first input. It's got a big multiline
456
+ description and also contains
457
+ (choices: ["a", "b", "c"])
458
+
459
+ y: The second input. It's a big list with a single-line description.
460
+
461
+ z: The third input. It's some kind of tuple with a default arg.
462
+
463
+ Returns:
464
+ The output. The return description is also a big multiline
465
+ description that spans multiple lines.
466
+ """
467
+ pass
468
+
469
+ schema = get_json_schema(fn)
470
+ expected_schema = {
471
+ "name": "fn",
472
+ "description": "Test function with multiple args, and docstring args that we have to strip out.",
473
+ "parameters": {
474
+ "type": "object",
475
+ "properties": {
476
+ "x": {
477
+ "type": "string",
478
+ "enum": ["a", "b", "c"],
479
+ "description": "The first input. It's got a big multiline description and also contains",
480
+ },
481
+ "y": {
482
+ "type": "array",
483
+ "items": {"type": ["integer", "string"]},
484
+ "nullable": True,
485
+ "description": "The second input. It's a big list with a single-line description.",
486
+ },
487
+ "z": {
488
+ "type": "array",
489
+ "prefixItems": [{"type": ["integer", "string"]}, {"type": "string"}],
490
+ "description": "The third input. It's some kind of tuple with a default arg.",
491
+ },
492
+ },
493
+ "required": ["x", "y"],
494
+ },
495
+ "return": {
496
+ "type": "array",
497
+ "prefixItems": [{"type": "integer"}, {"type": "string"}],
498
+ "description": "The output. The return description is also a big multiline\n description that spans multiple lines.",
499
+ },
500
+ }
501
+ self.assertEqual(schema["function"], expected_schema)
docs/transformers/tests/utils/test_cli.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2019-present, the HuggingFace Inc. team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import shutil
17
+ import unittest
18
+ from unittest.mock import patch
19
+
20
+ from transformers.testing_utils import CaptureStd, require_torch
21
+
22
+
23
+ class CLITest(unittest.TestCase):
24
+ @patch("sys.argv", ["fakeprogrampath", "env"])
25
+ def test_cli_env(self):
26
+ # test transformers-cli env
27
+ import transformers.commands.transformers_cli
28
+
29
+ with CaptureStd() as cs:
30
+ transformers.commands.transformers_cli.main()
31
+ self.assertIn("Python version", cs.out)
32
+ self.assertIn("Platform", cs.out)
33
+ self.assertIn("Using distributed or parallel set-up in script?", cs.out)
34
+
35
+ @require_torch
36
+ @patch("sys.argv", ["fakeprogrampath", "download", "hf-internal-testing/tiny-random-gptj", "--cache-dir", "/tmp"])
37
+ def test_cli_download(self):
38
+ import transformers.commands.transformers_cli
39
+
40
+ # # remove any previously downloaded model to start clean
41
+ shutil.rmtree("/tmp/models--hf-internal-testing--tiny-random-gptj", ignore_errors=True)
42
+
43
+ # run the command
44
+ transformers.commands.transformers_cli.main()
45
+
46
+ # check if the model files are downloaded correctly on /tmp/models--hf-internal-testing--tiny-random-gptj
47
+ self.assertTrue(os.path.exists("/tmp/models--hf-internal-testing--tiny-random-gptj/blobs"))
48
+ self.assertTrue(os.path.exists("/tmp/models--hf-internal-testing--tiny-random-gptj/refs"))
49
+ self.assertTrue(os.path.exists("/tmp/models--hf-internal-testing--tiny-random-gptj/snapshots"))
50
+
51
+ @require_torch
52
+ @patch(
53
+ "sys.argv",
54
+ [
55
+ "fakeprogrampath",
56
+ "download",
57
+ "hf-internal-testing/test_dynamic_model_with_tokenizer",
58
+ "--trust-remote-code",
59
+ "--cache-dir",
60
+ "/tmp",
61
+ ],
62
+ )
63
+ def test_cli_download_trust_remote(self):
64
+ import transformers.commands.transformers_cli
65
+
66
+ # # remove any previously downloaded model to start clean
67
+ shutil.rmtree("/tmp/models--hf-internal-testing--test_dynamic_model_with_tokenizer", ignore_errors=True)
68
+
69
+ # run the command
70
+ transformers.commands.transformers_cli.main()
71
+
72
+ # check if the model files are downloaded correctly on /tmp/models--hf-internal-testing--test_dynamic_model_with_tokenizer
73
+ self.assertTrue(os.path.exists("/tmp/models--hf-internal-testing--test_dynamic_model_with_tokenizer/blobs"))
74
+ self.assertTrue(os.path.exists("/tmp/models--hf-internal-testing--test_dynamic_model_with_tokenizer/refs"))
75
+ self.assertTrue(
76
+ os.path.exists("/tmp/models--hf-internal-testing--test_dynamic_model_with_tokenizer/snapshots")
77
+ )
docs/transformers/tests/utils/test_configuration_utils.py ADDED
@@ -0,0 +1,302 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2019 HuggingFace Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import json
15
+ import os
16
+ import shutil
17
+ import sys
18
+ import tempfile
19
+ import unittest
20
+ import unittest.mock as mock
21
+ import warnings
22
+ from pathlib import Path
23
+
24
+ from huggingface_hub import HfFolder
25
+ from requests.exceptions import HTTPError
26
+
27
+ from transformers import AutoConfig, BertConfig, GPT2Config
28
+ from transformers.configuration_utils import PretrainedConfig
29
+ from transformers.testing_utils import TOKEN, TemporaryHubRepo, is_staging_test
30
+
31
+
32
+ sys.path.append(str(Path(__file__).parent.parent.parent / "utils"))
33
+
34
+ from test_module.custom_configuration import CustomConfig # noqa E402
35
+
36
+
37
+ config_common_kwargs = {
38
+ "return_dict": False,
39
+ "output_hidden_states": True,
40
+ "output_attentions": True,
41
+ "torchscript": True,
42
+ "torch_dtype": "float16",
43
+ "use_bfloat16": True,
44
+ "tf_legacy_loss": True,
45
+ "pruned_heads": {"a": 1},
46
+ "tie_word_embeddings": False,
47
+ "is_decoder": True,
48
+ "cross_attention_hidden_size": 128,
49
+ "add_cross_attention": True,
50
+ "tie_encoder_decoder": True,
51
+ "max_length": 50,
52
+ "min_length": 3,
53
+ "do_sample": True,
54
+ "early_stopping": True,
55
+ "num_beams": 3,
56
+ "num_beam_groups": 3,
57
+ "diversity_penalty": 0.5,
58
+ "temperature": 2.0,
59
+ "top_k": 10,
60
+ "top_p": 0.7,
61
+ "typical_p": 0.2,
62
+ "repetition_penalty": 0.8,
63
+ "length_penalty": 0.8,
64
+ "no_repeat_ngram_size": 5,
65
+ "encoder_no_repeat_ngram_size": 5,
66
+ "bad_words_ids": [1, 2, 3],
67
+ "num_return_sequences": 3,
68
+ "chunk_size_feed_forward": 5,
69
+ "output_scores": True,
70
+ "return_dict_in_generate": True,
71
+ "forced_bos_token_id": 2,
72
+ "forced_eos_token_id": 3,
73
+ "remove_invalid_values": True,
74
+ "architectures": ["BertModel"],
75
+ "finetuning_task": "translation",
76
+ "id2label": {0: "label"},
77
+ "label2id": {"label": "0"},
78
+ "tokenizer_class": "BertTokenizerFast",
79
+ "prefix": "prefix",
80
+ "bos_token_id": 6,
81
+ "pad_token_id": 7,
82
+ "eos_token_id": 8,
83
+ "sep_token_id": 9,
84
+ "decoder_start_token_id": 10,
85
+ "exponential_decay_length_penalty": (5, 1.01),
86
+ "suppress_tokens": [0, 1],
87
+ "begin_suppress_tokens": 2,
88
+ "task_specific_params": {"translation": "some_params"},
89
+ "problem_type": "regression",
90
+ }
91
+
92
+
93
+ @is_staging_test
94
+ class ConfigPushToHubTester(unittest.TestCase):
95
+ @classmethod
96
+ def setUpClass(cls):
97
+ cls._token = TOKEN
98
+ HfFolder.save_token(TOKEN)
99
+
100
+ def test_push_to_hub(self):
101
+ with TemporaryHubRepo(token=self._token) as tmp_repo:
102
+ config = BertConfig(
103
+ vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
104
+ )
105
+ config.push_to_hub(tmp_repo.repo_id, token=self._token)
106
+
107
+ new_config = BertConfig.from_pretrained(tmp_repo.repo_id)
108
+ for k, v in config.to_dict().items():
109
+ if k != "transformers_version":
110
+ self.assertEqual(v, getattr(new_config, k))
111
+
112
+ def test_push_to_hub_via_save_pretrained(self):
113
+ with TemporaryHubRepo(token=self._token) as tmp_repo:
114
+ config = BertConfig(
115
+ vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
116
+ )
117
+ # Push to hub via save_pretrained
118
+ with tempfile.TemporaryDirectory() as tmp_dir:
119
+ config.save_pretrained(tmp_dir, repo_id=tmp_repo.repo_id, push_to_hub=True, token=self._token)
120
+
121
+ new_config = BertConfig.from_pretrained(tmp_repo.repo_id)
122
+ for k, v in config.to_dict().items():
123
+ if k != "transformers_version":
124
+ self.assertEqual(v, getattr(new_config, k))
125
+
126
+ def test_push_to_hub_in_organization(self):
127
+ with TemporaryHubRepo(namespace="valid_org", token=self._token) as tmp_repo:
128
+ config = BertConfig(
129
+ vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
130
+ )
131
+ config.push_to_hub(tmp_repo.repo_id, token=self._token)
132
+
133
+ new_config = BertConfig.from_pretrained(tmp_repo.repo_id)
134
+ for k, v in config.to_dict().items():
135
+ if k != "transformers_version":
136
+ self.assertEqual(v, getattr(new_config, k))
137
+
138
+ def test_push_to_hub_in_organization_via_save_pretrained(self):
139
+ with TemporaryHubRepo(namespace="valid_org", token=self._token) as tmp_repo:
140
+ config = BertConfig(
141
+ vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
142
+ )
143
+ # Push to hub via save_pretrained
144
+ with tempfile.TemporaryDirectory() as tmp_dir:
145
+ config.save_pretrained(tmp_dir, repo_id=tmp_repo.repo_id, push_to_hub=True, token=self._token)
146
+
147
+ new_config = BertConfig.from_pretrained(tmp_repo.repo_id)
148
+ for k, v in config.to_dict().items():
149
+ if k != "transformers_version":
150
+ self.assertEqual(v, getattr(new_config, k))
151
+
152
+ def test_push_to_hub_dynamic_config(self):
153
+ with TemporaryHubRepo(token=self._token) as tmp_repo:
154
+ CustomConfig.register_for_auto_class()
155
+ config = CustomConfig(attribute=42)
156
+
157
+ config.push_to_hub(tmp_repo.repo_id, token=self._token)
158
+
159
+ # This has added the proper auto_map field to the config
160
+ self.assertDictEqual(config.auto_map, {"AutoConfig": "custom_configuration.CustomConfig"})
161
+
162
+ new_config = AutoConfig.from_pretrained(tmp_repo.repo_id, trust_remote_code=True)
163
+ # Can't make an isinstance check because the new_config is from the FakeConfig class of a dynamic module
164
+ self.assertEqual(new_config.__class__.__name__, "CustomConfig")
165
+ self.assertEqual(new_config.attribute, 42)
166
+
167
+
168
+ class ConfigTestUtils(unittest.TestCase):
169
+ def test_config_from_string(self):
170
+ c = GPT2Config()
171
+
172
+ # attempt to modify each of int/float/bool/str config records and verify they were updated
173
+ n_embd = c.n_embd + 1 # int
174
+ resid_pdrop = c.resid_pdrop + 1.0 # float
175
+ scale_attn_weights = not c.scale_attn_weights # bool
176
+ summary_type = c.summary_type + "foo" # str
177
+ c.update_from_string(
178
+ f"n_embd={n_embd},resid_pdrop={resid_pdrop},scale_attn_weights={scale_attn_weights},summary_type={summary_type}"
179
+ )
180
+ self.assertEqual(n_embd, c.n_embd, "mismatch for key: n_embd")
181
+ self.assertEqual(resid_pdrop, c.resid_pdrop, "mismatch for key: resid_pdrop")
182
+ self.assertEqual(scale_attn_weights, c.scale_attn_weights, "mismatch for key: scale_attn_weights")
183
+ self.assertEqual(summary_type, c.summary_type, "mismatch for key: summary_type")
184
+
185
+ def test_config_common_kwargs_is_complete(self):
186
+ base_config = PretrainedConfig()
187
+ missing_keys = [key for key in base_config.__dict__ if key not in config_common_kwargs]
188
+ # If this part of the test fails, you have arguments to addin config_common_kwargs above.
189
+ self.assertListEqual(
190
+ missing_keys,
191
+ [
192
+ "is_encoder_decoder",
193
+ "_name_or_path",
194
+ "_commit_hash",
195
+ "_attn_implementation_internal",
196
+ "_attn_implementation_autoset",
197
+ "transformers_version",
198
+ ],
199
+ )
200
+ keys_with_defaults = [key for key, value in config_common_kwargs.items() if value == getattr(base_config, key)]
201
+ if len(keys_with_defaults) > 0:
202
+ raise ValueError(
203
+ "The following keys are set with the default values in"
204
+ " `test_configuration_common.config_common_kwargs` pick another value for them:"
205
+ f" {', '.join(keys_with_defaults)}."
206
+ )
207
+
208
+ def test_nested_config_load_from_dict(self):
209
+ config = AutoConfig.from_pretrained(
210
+ "hf-internal-testing/tiny-random-CLIPModel", text_config={"num_hidden_layers": 2}
211
+ )
212
+ self.assertNotIsInstance(config.text_config, dict)
213
+ self.assertEqual(config.text_config.__class__.__name__, "CLIPTextConfig")
214
+
215
+ def test_from_pretrained_subfolder(self):
216
+ config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert-subfolder")
217
+ self.assertIsNotNone(config)
218
+
219
+ config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert-subfolder", subfolder="bert")
220
+ self.assertIsNotNone(config)
221
+
222
+ def test_cached_files_are_used_when_internet_is_down(self):
223
+ # A mock response for an HTTP head request to emulate server down
224
+ response_mock = mock.Mock()
225
+ response_mock.status_code = 500
226
+ response_mock.headers = {}
227
+ response_mock.raise_for_status.side_effect = HTTPError
228
+ response_mock.json.return_value = {}
229
+
230
+ # Download this model to make sure it's in the cache.
231
+ _ = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
232
+
233
+ # Under the mock environment we get a 500 error when trying to reach the model.
234
+ with mock.patch("requests.Session.request", return_value=response_mock) as mock_head:
235
+ _ = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
236
+ # This check we did call the fake head request
237
+ mock_head.assert_called()
238
+
239
+ def test_local_versioning(self):
240
+ configuration = AutoConfig.from_pretrained("google-bert/bert-base-cased")
241
+ configuration.configuration_files = ["config.4.0.0.json"]
242
+
243
+ with tempfile.TemporaryDirectory() as tmp_dir:
244
+ configuration.save_pretrained(tmp_dir)
245
+ configuration.hidden_size = 2
246
+ json.dump(configuration.to_dict(), open(os.path.join(tmp_dir, "config.4.0.0.json"), "w"))
247
+
248
+ # This should pick the new configuration file as the version of Transformers is > 4.0.0
249
+ new_configuration = AutoConfig.from_pretrained(tmp_dir)
250
+ self.assertEqual(new_configuration.hidden_size, 2)
251
+
252
+ # Will need to be adjusted if we reach v42 and this test is still here.
253
+ # Should pick the old configuration file as the version of Transformers is < 4.42.0
254
+ configuration.configuration_files = ["config.42.0.0.json"]
255
+ configuration.hidden_size = 768
256
+ configuration.save_pretrained(tmp_dir)
257
+ shutil.move(os.path.join(tmp_dir, "config.4.0.0.json"), os.path.join(tmp_dir, "config.42.0.0.json"))
258
+ new_configuration = AutoConfig.from_pretrained(tmp_dir)
259
+ self.assertEqual(new_configuration.hidden_size, 768)
260
+
261
+ def test_repo_versioning_before(self):
262
+ # This repo has two configuration files, one for v4.0.0 and above with a different hidden size.
263
+ repo = "hf-internal-testing/test-two-configs"
264
+
265
+ import transformers as new_transformers
266
+
267
+ new_transformers.configuration_utils.__version__ = "v4.0.0"
268
+ new_configuration, kwargs = new_transformers.models.auto.AutoConfig.from_pretrained(
269
+ repo, return_unused_kwargs=True
270
+ )
271
+ self.assertEqual(new_configuration.hidden_size, 2)
272
+ # This checks `_configuration_file` ia not kept in the kwargs by mistake.
273
+ self.assertDictEqual(kwargs, {})
274
+
275
+ # Testing an older version by monkey-patching the version in the module it's used.
276
+ import transformers as old_transformers
277
+
278
+ old_transformers.configuration_utils.__version__ = "v3.0.0"
279
+ old_configuration = old_transformers.models.auto.AutoConfig.from_pretrained(repo)
280
+ self.assertEqual(old_configuration.hidden_size, 768)
281
+
282
+ def test_saving_config_with_custom_generation_kwargs_raises_warning(self):
283
+ config = BertConfig(min_length=3) # `min_length = 3` is a non-default generation kwarg
284
+ with tempfile.TemporaryDirectory() as tmp_dir:
285
+ with self.assertWarns(UserWarning) as cm:
286
+ config.save_pretrained(tmp_dir)
287
+ self.assertIn("min_length", str(cm.warning))
288
+
289
+ def test_get_non_default_generation_parameters(self):
290
+ config = BertConfig()
291
+ self.assertFalse(len(config._get_non_default_generation_parameters()) > 0)
292
+ config = BertConfig(min_length=3)
293
+ self.assertTrue(len(config._get_non_default_generation_parameters()) > 0)
294
+ config = BertConfig(min_length=0) # `min_length = 0` is a default generation kwarg
295
+ self.assertFalse(len(config._get_non_default_generation_parameters()) > 0)
296
+
297
+ def test_loading_config_do_not_raise_future_warnings(self):
298
+ """Regression test for https://github.com/huggingface/transformers/issues/31002."""
299
+ # Loading config should not raise a FutureWarning. It was the case before.
300
+ with warnings.catch_warnings():
301
+ warnings.simplefilter("error")
302
+ PretrainedConfig.from_pretrained("bert-base-uncased")
docs/transformers/tests/utils/test_convert_slow_tokenizer.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+ import warnings
3
+ from dataclasses import dataclass
4
+
5
+ from transformers.convert_slow_tokenizer import SpmConverter
6
+ from transformers.testing_utils import get_tests_dir
7
+
8
+
9
+ @dataclass
10
+ class FakeOriginalTokenizer:
11
+ vocab_file: str
12
+
13
+
14
+ class ConvertSlowTokenizerTest(unittest.TestCase):
15
+ def test_spm_converter_bytefallback_warning(self):
16
+ spm_model_file_without_bytefallback = get_tests_dir("fixtures/test_sentencepiece.model")
17
+ spm_model_file_with_bytefallback = get_tests_dir("fixtures/test_sentencepiece_with_bytefallback.model")
18
+
19
+ original_tokenizer_without_bytefallback = FakeOriginalTokenizer(vocab_file=spm_model_file_without_bytefallback)
20
+
21
+ with warnings.catch_warnings(record=True) as w:
22
+ _ = SpmConverter(original_tokenizer_without_bytefallback)
23
+ self.assertEqual(len(w), 0)
24
+
25
+ original_tokenizer_with_bytefallback = FakeOriginalTokenizer(vocab_file=spm_model_file_with_bytefallback)
26
+
27
+ with warnings.catch_warnings(record=True) as w:
28
+ _ = SpmConverter(original_tokenizer_with_bytefallback)
29
+ self.assertEqual(len(w), 1)
30
+
31
+ self.assertIn(
32
+ "The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option"
33
+ " which is not implemented in the fast tokenizers.",
34
+ str(w[0].message),
35
+ )
docs/transformers/tests/utils/test_deprecation.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import unittest
16
+ import warnings
17
+
18
+ from parameterized import parameterized
19
+
20
+ from transformers import __version__, is_torch_available
21
+ from transformers.testing_utils import require_torch_gpu
22
+ from transformers.utils.deprecation import deprecate_kwarg
23
+
24
+
25
+ if is_torch_available():
26
+ import torch
27
+
28
+
29
+ INFINITE_VERSION = "9999.0.0"
30
+
31
+
32
+ class DeprecationDecoratorTester(unittest.TestCase):
33
+ def test_rename_kwarg(self):
34
+ with warnings.catch_warnings():
35
+ warnings.simplefilter("ignore")
36
+
37
+ @deprecate_kwarg("deprecated_name", new_name="new_name", version=INFINITE_VERSION)
38
+ def dummy_function(new_name=None, other_name=None):
39
+ return new_name, other_name
40
+
41
+ # Test keyword argument is renamed
42
+ value, other_value = dummy_function(deprecated_name="old_value")
43
+ self.assertEqual(value, "old_value")
44
+ self.assertIsNone(other_value)
45
+
46
+ # Test deprecated keyword argument not passed
47
+ value, other_value = dummy_function(new_name="new_value")
48
+ self.assertEqual(value, "new_value")
49
+ self.assertIsNone(other_value)
50
+
51
+ # Test other keyword argument
52
+ value, other_value = dummy_function(other_name="other_value")
53
+ self.assertIsNone(value)
54
+ self.assertEqual(other_value, "other_value")
55
+
56
+ # Test deprecated and new args are passed, the new one should be returned
57
+ value, other_value = dummy_function(deprecated_name="old_value", new_name="new_value")
58
+ self.assertEqual(value, "new_value")
59
+ self.assertIsNone(other_value)
60
+
61
+ def test_rename_multiple_kwargs(self):
62
+ with warnings.catch_warnings():
63
+ warnings.simplefilter("ignore")
64
+
65
+ @deprecate_kwarg("deprecated_name1", new_name="new_name1", version=INFINITE_VERSION)
66
+ @deprecate_kwarg("deprecated_name2", new_name="new_name2", version=INFINITE_VERSION)
67
+ def dummy_function(new_name1=None, new_name2=None, other_name=None):
68
+ return new_name1, new_name2, other_name
69
+
70
+ # Test keyword argument is renamed
71
+ value1, value2, other_value = dummy_function(deprecated_name1="old_value1", deprecated_name2="old_value2")
72
+ self.assertEqual(value1, "old_value1")
73
+ self.assertEqual(value2, "old_value2")
74
+ self.assertIsNone(other_value)
75
+
76
+ # Test deprecated keyword argument is not passed
77
+ value1, value2, other_value = dummy_function(new_name1="new_value1", new_name2="new_value2")
78
+ self.assertEqual(value1, "new_value1")
79
+ self.assertEqual(value2, "new_value2")
80
+ self.assertIsNone(other_value)
81
+
82
+ # Test other keyword argument is passed and correctly returned
83
+ value1, value2, other_value = dummy_function(other_name="other_value")
84
+ self.assertIsNone(value1)
85
+ self.assertIsNone(value2)
86
+ self.assertEqual(other_value, "other_value")
87
+
88
+ def test_warnings(self):
89
+ # Test warning is raised for future version
90
+ @deprecate_kwarg("deprecated_name", new_name="new_name", version=INFINITE_VERSION)
91
+ def dummy_function(new_name=None, other_name=None):
92
+ return new_name, other_name
93
+
94
+ with self.assertWarns(FutureWarning):
95
+ dummy_function(deprecated_name="old_value")
96
+
97
+ # Test warning is not raised for past version, but arg is still renamed
98
+ @deprecate_kwarg("deprecated_name", new_name="new_name", version="0.0.0")
99
+ def dummy_function(new_name=None, other_name=None):
100
+ return new_name, other_name
101
+
102
+ with warnings.catch_warnings(record=True) as raised_warnings:
103
+ warnings.simplefilter("always")
104
+
105
+ value, other_value = dummy_function(deprecated_name="old_value")
106
+
107
+ self.assertEqual(value, "old_value")
108
+ self.assertIsNone(other_value)
109
+ self.assertEqual(len(raised_warnings), 0, f"Warning raised: {[w.message for w in raised_warnings]}")
110
+
111
+ # Test warning is raised for future version if warn_if_greater_or_equal_version is set
112
+ @deprecate_kwarg("deprecated_name", version="0.0.0", warn_if_greater_or_equal_version=True)
113
+ def dummy_function(deprecated_name=None):
114
+ return deprecated_name
115
+
116
+ with self.assertWarns(FutureWarning):
117
+ value = dummy_function(deprecated_name="deprecated_value")
118
+ self.assertEqual(value, "deprecated_value")
119
+
120
+ # Test arg is not renamed if new_name is not specified, but warning is raised
121
+ @deprecate_kwarg("deprecated_name", version=INFINITE_VERSION)
122
+ def dummy_function(deprecated_name=None):
123
+ return deprecated_name
124
+
125
+ with self.assertWarns(FutureWarning):
126
+ value = dummy_function(deprecated_name="deprecated_value")
127
+ self.assertEqual(value, "deprecated_value")
128
+
129
+ def test_raises(self):
130
+ # Test if deprecated name and new name are both passed and raise_if_both_names is set -> raise error
131
+ @deprecate_kwarg("deprecated_name", new_name="new_name", version=INFINITE_VERSION, raise_if_both_names=True)
132
+ def dummy_function(new_name=None, other_name=None):
133
+ return new_name, other_name
134
+
135
+ with self.assertRaises(ValueError):
136
+ dummy_function(deprecated_name="old_value", new_name="new_value")
137
+
138
+ # Test for current version == deprecation version
139
+ @deprecate_kwarg("deprecated_name", version=__version__, raise_if_greater_or_equal_version=True)
140
+ def dummy_function(deprecated_name=None):
141
+ return deprecated_name
142
+
143
+ with self.assertRaises(ValueError):
144
+ dummy_function(deprecated_name="old_value")
145
+
146
+ # Test for current version > deprecation version
147
+ @deprecate_kwarg("deprecated_name", version="0.0.0", raise_if_greater_or_equal_version=True)
148
+ def dummy_function(deprecated_name=None):
149
+ return deprecated_name
150
+
151
+ with self.assertRaises(ValueError):
152
+ dummy_function(deprecated_name="old_value")
153
+
154
+ def test_additional_message(self):
155
+ # Test additional message is added to the warning
156
+ @deprecate_kwarg("deprecated_name", version=INFINITE_VERSION, additional_message="Additional message")
157
+ def dummy_function(deprecated_name=None):
158
+ return deprecated_name
159
+
160
+ with warnings.catch_warnings(record=True) as raised_warnings:
161
+ warnings.simplefilter("always")
162
+ dummy_function(deprecated_name="old_value")
163
+
164
+ self.assertTrue("Additional message" in str(raised_warnings[0].message))
165
+
166
+ @parameterized.expand(["0.0.0", __version__, INFINITE_VERSION])
167
+ def test_warning_for_both_names(self, version):
168
+ # We should raise warning if both names are passed for any specified version
169
+ @deprecate_kwarg("deprecated_name", new_name="new_name", version=version)
170
+ def dummy_function(new_name=None, **kwargs):
171
+ return new_name
172
+
173
+ with self.assertWarns(FutureWarning):
174
+ result = dummy_function(deprecated_name="old_value", new_name="new_value")
175
+ self.assertEqual(result, "new_value")
176
+
177
+ @require_torch_gpu
178
+ def test_compile_safe(self):
179
+ @deprecate_kwarg("deprecated_factor", new_name="new_factor", version=INFINITE_VERSION)
180
+ def dummy_function(new_factor=None, **kwargs):
181
+ return new_factor * torch.ones(1, device="cuda")
182
+
183
+ compiled_function = torch.compile(dummy_function, fullgraph=True)
184
+
185
+ # Check that we can correctly call the compiled function with the old name, without raising errors
186
+ out = compiled_function(deprecated_factor=2)
187
+ self.assertEqual(out.item(), 2)
188
+
189
+ # Check that we can correctly call the compiled function with the new name, without raising errors
190
+ out = compiled_function(new_factor=2)
191
+ self.assertEqual(out.item(), 2)
192
+
193
+ # Check that we can correctly call the compiled function with both names, without raising errors
194
+ out = compiled_function(new_factor=2, deprecated_factor=10)
195
+ self.assertEqual(out.item(), 2)
docs/transformers/tests/utils/test_doc_samples.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2019-present, the HuggingFace Inc. team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import doctest
15
+ import logging
16
+ import os
17
+ import unittest
18
+ from pathlib import Path
19
+ from typing import Union
20
+
21
+ import transformers
22
+ from transformers.testing_utils import require_tf, require_torch, slow
23
+
24
+
25
+ logger = logging.getLogger()
26
+
27
+
28
+ @unittest.skip(reason="Temporarily disable the doc tests.")
29
+ @require_torch
30
+ @require_tf
31
+ @slow
32
+ class TestCodeExamples(unittest.TestCase):
33
+ def analyze_directory(
34
+ self,
35
+ directory: Path,
36
+ identifier: Union[str, None] = None,
37
+ ignore_files: Union[list[str], None] = None,
38
+ n_identifier: Union[str, list[str], None] = None,
39
+ only_modules: bool = True,
40
+ ):
41
+ """
42
+ Runs through the specific directory, looking for the files identified with `identifier`. Executes
43
+ the doctests in those files
44
+
45
+ Args:
46
+ directory (`Path`): Directory containing the files
47
+ identifier (`str`): Will parse files containing this
48
+ ignore_files (`List[str]`): List of files to skip
49
+ n_identifier (`str` or `List[str]`): Will not parse files containing this/these identifiers.
50
+ only_modules (`bool`): Whether to only analyze modules
51
+ """
52
+ files = [file for file in os.listdir(directory) if os.path.isfile(os.path.join(directory, file))]
53
+
54
+ if identifier is not None:
55
+ files = [file for file in files if identifier in file]
56
+
57
+ if n_identifier is not None:
58
+ if isinstance(n_identifier, list):
59
+ for n_ in n_identifier:
60
+ files = [file for file in files if n_ not in file]
61
+ else:
62
+ files = [file for file in files if n_identifier not in file]
63
+
64
+ ignore_files = ignore_files or []
65
+ ignore_files.append("__init__.py")
66
+ files = [file for file in files if file not in ignore_files]
67
+
68
+ for file in files:
69
+ # Open all files
70
+ print("Testing", file)
71
+
72
+ if only_modules:
73
+ module_identifier = file.split(".")[0]
74
+ try:
75
+ module_identifier = getattr(transformers, module_identifier)
76
+ suite = doctest.DocTestSuite(module_identifier)
77
+ result = unittest.TextTestRunner().run(suite)
78
+ self.assertIs(len(result.failures), 0)
79
+ except AttributeError:
80
+ logger.info(f"{module_identifier} is not a module.")
81
+ else:
82
+ result = doctest.testfile(str(".." / directory / file), optionflags=doctest.ELLIPSIS)
83
+ self.assertIs(result.failed, 0)
84
+
85
+ def test_modeling_examples(self):
86
+ transformers_directory = Path("src/transformers")
87
+ files = "modeling"
88
+ ignore_files = [
89
+ "modeling_ctrl.py",
90
+ "modeling_tf_ctrl.py",
91
+ ]
92
+ self.analyze_directory(transformers_directory, identifier=files, ignore_files=ignore_files)
93
+
94
+ def test_tokenization_examples(self):
95
+ transformers_directory = Path("src/transformers")
96
+ files = "tokenization"
97
+ self.analyze_directory(transformers_directory, identifier=files)
98
+
99
+ def test_configuration_examples(self):
100
+ transformers_directory = Path("src/transformers")
101
+ files = "configuration"
102
+ self.analyze_directory(transformers_directory, identifier=files)
103
+
104
+ def test_remaining_examples(self):
105
+ transformers_directory = Path("src/transformers")
106
+ n_identifiers = ["configuration", "modeling", "tokenization"]
107
+ self.analyze_directory(transformers_directory, n_identifier=n_identifiers)
108
+
109
+ def test_doc_sources(self):
110
+ doc_source_directory = Path("docs/source")
111
+ ignore_files = ["favicon.ico"]
112
+ self.analyze_directory(doc_source_directory, ignore_files=ignore_files, only_modules=False)
docs/transformers/tests/utils/test_dynamic_module_utils.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+
17
+ import pytest
18
+
19
+ from transformers.dynamic_module_utils import get_imports
20
+
21
+
22
+ TOP_LEVEL_IMPORT = """
23
+ import os
24
+ """
25
+
26
+ IMPORT_IN_FUNCTION = """
27
+ def foo():
28
+ import os
29
+ return False
30
+ """
31
+
32
+ DEEPLY_NESTED_IMPORT = """
33
+ def foo():
34
+ def bar():
35
+ if True:
36
+ import os
37
+ return False
38
+ return bar()
39
+ """
40
+
41
+ TOP_LEVEL_TRY_IMPORT = """
42
+ import os
43
+
44
+ try:
45
+ import bar
46
+ except ImportError:
47
+ raise ValueError()
48
+ """
49
+
50
+ TRY_IMPORT_IN_FUNCTION = """
51
+ import os
52
+
53
+ def foo():
54
+ try:
55
+ import bar
56
+ except ImportError:
57
+ raise ValueError()
58
+ """
59
+
60
+ MULTIPLE_EXCEPTS_IMPORT = """
61
+ import os
62
+
63
+ try:
64
+ import bar
65
+ except (ImportError, AttributeError):
66
+ raise ValueError()
67
+ """
68
+
69
+ EXCEPT_AS_IMPORT = """
70
+ import os
71
+
72
+ try:
73
+ import bar
74
+ except ImportError as e:
75
+ raise ValueError()
76
+ """
77
+
78
+ GENERIC_EXCEPT_IMPORT = """
79
+ import os
80
+
81
+ try:
82
+ import bar
83
+ except:
84
+ raise ValueError()
85
+ """
86
+
87
+ MULTILINE_TRY_IMPORT = """
88
+ import os
89
+
90
+ try:
91
+ import bar
92
+ import baz
93
+ except ImportError:
94
+ raise ValueError()
95
+ """
96
+
97
+ MULTILINE_BOTH_IMPORT = """
98
+ import os
99
+
100
+ try:
101
+ import bar
102
+ import baz
103
+ except ImportError:
104
+ x = 1
105
+ raise ValueError()
106
+ """
107
+
108
+ CASES = [
109
+ TOP_LEVEL_IMPORT,
110
+ IMPORT_IN_FUNCTION,
111
+ DEEPLY_NESTED_IMPORT,
112
+ TOP_LEVEL_TRY_IMPORT,
113
+ GENERIC_EXCEPT_IMPORT,
114
+ MULTILINE_TRY_IMPORT,
115
+ MULTILINE_BOTH_IMPORT,
116
+ MULTIPLE_EXCEPTS_IMPORT,
117
+ EXCEPT_AS_IMPORT,
118
+ TRY_IMPORT_IN_FUNCTION,
119
+ ]
120
+
121
+
122
+ @pytest.mark.parametrize("case", CASES)
123
+ def test_import_parsing(tmp_path, case):
124
+ tmp_file_path = os.path.join(tmp_path, "test_file.py")
125
+ with open(tmp_file_path, "w") as _tmp_file:
126
+ _tmp_file.write(case)
127
+
128
+ parsed_imports = get_imports(tmp_file_path)
129
+ assert parsed_imports == ["os"]
docs/transformers/tests/utils/test_expectations.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+
3
+ from transformers.testing_utils import Expectations
4
+
5
+
6
+ class ExpectationsTest(unittest.TestCase):
7
+ def test_expectations(self):
8
+ expectations = Expectations(
9
+ {
10
+ (None, None): 1,
11
+ ("cuda", 8): 2,
12
+ ("cuda", 7): 3,
13
+ ("rocm", 8): 4,
14
+ ("rocm", None): 5,
15
+ ("cpu", None): 6,
16
+ ("xpu", 3): 7,
17
+ }
18
+ )
19
+
20
+ def check(value, key):
21
+ assert expectations.find_expectation(key) == value
22
+
23
+ # npu has no matches so should find default expectation
24
+ check(1, ("npu", None))
25
+ check(7, ("xpu", 3))
26
+ check(2, ("cuda", 8))
27
+ check(3, ("cuda", 7))
28
+ check(4, ("rocm", 9))
29
+ check(4, ("rocm", None))
30
+ check(2, ("cuda", 2))
31
+
32
+ expectations = Expectations({("cuda", 8): 1})
33
+ with self.assertRaises(ValueError):
34
+ expectations.find_expectation(("xpu", None))
docs/transformers/tests/utils/test_feature_extraction_utils.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 HuggingFace Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import sys
17
+ import tempfile
18
+ import unittest
19
+ import unittest.mock as mock
20
+ from pathlib import Path
21
+
22
+ from huggingface_hub import HfFolder
23
+ from requests.exceptions import HTTPError
24
+
25
+ from transformers import AutoFeatureExtractor, Wav2Vec2FeatureExtractor
26
+ from transformers.testing_utils import TOKEN, TemporaryHubRepo, get_tests_dir, is_staging_test
27
+
28
+
29
+ sys.path.append(str(Path(__file__).parent.parent.parent / "utils"))
30
+
31
+ from test_module.custom_feature_extraction import CustomFeatureExtractor # noqa E402
32
+
33
+
34
+ SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR = get_tests_dir("fixtures")
35
+
36
+
37
+ class FeatureExtractorUtilTester(unittest.TestCase):
38
+ def test_cached_files_are_used_when_internet_is_down(self):
39
+ # A mock response for an HTTP head request to emulate server down
40
+ response_mock = mock.Mock()
41
+ response_mock.status_code = 500
42
+ response_mock.headers = {}
43
+ response_mock.raise_for_status.side_effect = HTTPError
44
+ response_mock.json.return_value = {}
45
+
46
+ # Download this model to make sure it's in the cache.
47
+ _ = Wav2Vec2FeatureExtractor.from_pretrained("hf-internal-testing/tiny-random-wav2vec2")
48
+ # Under the mock environment we get a 500 error when trying to reach the model.
49
+ with mock.patch("requests.Session.request", return_value=response_mock) as mock_head:
50
+ _ = Wav2Vec2FeatureExtractor.from_pretrained("hf-internal-testing/tiny-random-wav2vec2")
51
+ # This check we did call the fake head request
52
+ mock_head.assert_called()
53
+
54
+
55
+ @is_staging_test
56
+ class FeatureExtractorPushToHubTester(unittest.TestCase):
57
+ @classmethod
58
+ def setUpClass(cls):
59
+ cls._token = TOKEN
60
+ HfFolder.save_token(TOKEN)
61
+
62
+ def test_push_to_hub(self):
63
+ with TemporaryHubRepo(token=self._token) as tmp_repo:
64
+ feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
65
+ feature_extractor.push_to_hub(tmp_repo.repo_id, token=self._token)
66
+
67
+ new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(tmp_repo.repo_id)
68
+ for k, v in feature_extractor.__dict__.items():
69
+ self.assertEqual(v, getattr(new_feature_extractor, k))
70
+
71
+ def test_push_to_hub_via_save_pretrained(self):
72
+ with TemporaryHubRepo(token=self._token) as tmp_repo:
73
+ feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
74
+ # Push to hub via save_pretrained
75
+ with tempfile.TemporaryDirectory() as tmp_dir:
76
+ feature_extractor.save_pretrained(
77
+ tmp_dir, repo_id=tmp_repo.repo_id, push_to_hub=True, token=self._token
78
+ )
79
+
80
+ new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(tmp_repo.repo_id)
81
+ for k, v in feature_extractor.__dict__.items():
82
+ self.assertEqual(v, getattr(new_feature_extractor, k))
83
+
84
+ def test_push_to_hub_in_organization(self):
85
+ with TemporaryHubRepo(namespace="valid_org", token=self._token) as tmp_repo:
86
+ feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
87
+ feature_extractor.push_to_hub(tmp_repo.repo_id, token=self._token)
88
+
89
+ new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(tmp_repo.repo_id)
90
+ for k, v in feature_extractor.__dict__.items():
91
+ self.assertEqual(v, getattr(new_feature_extractor, k))
92
+
93
+ def test_push_to_hub_in_organization_via_save_pretrained(self):
94
+ with TemporaryHubRepo(namespace="valid_org", token=self._token) as tmp_repo:
95
+ feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
96
+ # Push to hub via save_pretrained
97
+ with tempfile.TemporaryDirectory() as tmp_dir:
98
+ feature_extractor.save_pretrained(
99
+ tmp_dir, repo_id=tmp_repo.repo_id, push_to_hub=True, token=self._token
100
+ )
101
+
102
+ new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(tmp_repo.repo_id)
103
+ for k, v in feature_extractor.__dict__.items():
104
+ self.assertEqual(v, getattr(new_feature_extractor, k))
105
+
106
+ def test_push_to_hub_dynamic_feature_extractor(self):
107
+ with TemporaryHubRepo(token=self._token) as tmp_repo:
108
+ CustomFeatureExtractor.register_for_auto_class()
109
+ feature_extractor = CustomFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
110
+
111
+ feature_extractor.push_to_hub(tmp_repo.repo_id, token=self._token)
112
+
113
+ # This has added the proper auto_map field to the config
114
+ self.assertDictEqual(
115
+ feature_extractor.auto_map,
116
+ {"AutoFeatureExtractor": "custom_feature_extraction.CustomFeatureExtractor"},
117
+ )
118
+
119
+ new_feature_extractor = AutoFeatureExtractor.from_pretrained(tmp_repo.repo_id, trust_remote_code=True)
120
+ # Can't make an isinstance check because the new_feature_extractor is from the CustomFeatureExtractor class of a dynamic module
121
+ self.assertEqual(new_feature_extractor.__class__.__name__, "CustomFeatureExtractor")
docs/transformers/tests/utils/test_file_utils.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import contextlib
16
+ import importlib
17
+ import io
18
+ import unittest
19
+
20
+ import transformers
21
+
22
+ # Try to import everything from transformers to ensure every object can be loaded.
23
+ from transformers import * # noqa F406
24
+ from transformers.testing_utils import DUMMY_UNKNOWN_IDENTIFIER, require_flax, require_tf, require_torch
25
+ from transformers.utils import ContextManagers, find_labels, is_flax_available, is_tf_available, is_torch_available
26
+
27
+
28
+ if is_torch_available():
29
+ from transformers import BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification
30
+
31
+ if is_tf_available():
32
+ from transformers import TFBertForPreTraining, TFBertForQuestionAnswering, TFBertForSequenceClassification
33
+
34
+ if is_flax_available():
35
+ from transformers import FlaxBertForPreTraining, FlaxBertForQuestionAnswering, FlaxBertForSequenceClassification
36
+
37
+
38
+ MODEL_ID = DUMMY_UNKNOWN_IDENTIFIER
39
+ # An actual model hosted on huggingface.co
40
+
41
+ REVISION_ID_DEFAULT = "main"
42
+ # Default branch name
43
+ REVISION_ID_ONE_SPECIFIC_COMMIT = "f2c752cfc5c0ab6f4bdec59acea69eefbee381c2"
44
+ # One particular commit (not the top of `main`)
45
+ REVISION_ID_INVALID = "aaaaaaa"
46
+ # This commit does not exist, so we should 404.
47
+
48
+ PINNED_SHA1 = "d9e9f15bc825e4b2c9249e9578f884bbcb5e3684"
49
+ # Sha-1 of config.json on the top of `main`, for checking purposes
50
+ PINNED_SHA256 = "4b243c475af8d0a7754e87d7d096c92e5199ec2fe168a2ee7998e3b8e9bcb1d3"
51
+ # Sha-256 of pytorch_model.bin on the top of `main`, for checking purposes
52
+
53
+
54
+ # Dummy contexts to test `ContextManagers`
55
+ @contextlib.contextmanager
56
+ def context_en():
57
+ print("Welcome!")
58
+ yield
59
+ print("Bye!")
60
+
61
+
62
+ @contextlib.contextmanager
63
+ def context_fr():
64
+ print("Bonjour!")
65
+ yield
66
+ print("Au revoir!")
67
+
68
+
69
+ class TestImportMechanisms(unittest.TestCase):
70
+ def test_module_spec_available(self):
71
+ # If the spec is missing, importlib would not be able to import the module dynamically.
72
+ assert transformers.__spec__ is not None
73
+ assert importlib.util.find_spec("transformers") is not None
74
+
75
+
76
+ class GenericUtilTests(unittest.TestCase):
77
+ @unittest.mock.patch("sys.stdout", new_callable=io.StringIO)
78
+ def test_context_managers_no_context(self, mock_stdout):
79
+ with ContextManagers([]):
80
+ print("Transformers are awesome!")
81
+ # The print statement adds a new line at the end of the output
82
+ self.assertEqual(mock_stdout.getvalue(), "Transformers are awesome!\n")
83
+
84
+ @unittest.mock.patch("sys.stdout", new_callable=io.StringIO)
85
+ def test_context_managers_one_context(self, mock_stdout):
86
+ with ContextManagers([context_en()]):
87
+ print("Transformers are awesome!")
88
+ # The output should be wrapped with an English welcome and goodbye
89
+ self.assertEqual(mock_stdout.getvalue(), "Welcome!\nTransformers are awesome!\nBye!\n")
90
+
91
+ @unittest.mock.patch("sys.stdout", new_callable=io.StringIO)
92
+ def test_context_managers_two_context(self, mock_stdout):
93
+ with ContextManagers([context_fr(), context_en()]):
94
+ print("Transformers are awesome!")
95
+ # The output should be wrapped with an English and French welcome and goodbye
96
+ self.assertEqual(mock_stdout.getvalue(), "Bonjour!\nWelcome!\nTransformers are awesome!\nBye!\nAu revoir!\n")
97
+
98
+ @require_torch
99
+ def test_find_labels_pt(self):
100
+ self.assertEqual(find_labels(BertForSequenceClassification), ["labels"])
101
+ self.assertEqual(find_labels(BertForPreTraining), ["labels", "next_sentence_label"])
102
+ self.assertEqual(find_labels(BertForQuestionAnswering), ["start_positions", "end_positions"])
103
+
104
+ # find_labels works regardless of the class name (it detects the framework through inheritance)
105
+ class DummyModel(BertForSequenceClassification):
106
+ pass
107
+
108
+ self.assertEqual(find_labels(DummyModel), ["labels"])
109
+
110
+ @require_tf
111
+ def test_find_labels_tf(self):
112
+ self.assertEqual(find_labels(TFBertForSequenceClassification), ["labels"])
113
+ self.assertEqual(find_labels(TFBertForPreTraining), ["labels", "next_sentence_label"])
114
+ self.assertEqual(find_labels(TFBertForQuestionAnswering), ["start_positions", "end_positions"])
115
+
116
+ # find_labels works regardless of the class name (it detects the framework through inheritance)
117
+ class DummyModel(TFBertForSequenceClassification):
118
+ pass
119
+
120
+ self.assertEqual(find_labels(DummyModel), ["labels"])
121
+
122
+ @require_flax
123
+ def test_find_labels_flax(self):
124
+ # Flax models don't have labels
125
+ self.assertEqual(find_labels(FlaxBertForSequenceClassification), [])
126
+ self.assertEqual(find_labels(FlaxBertForPreTraining), [])
127
+ self.assertEqual(find_labels(FlaxBertForQuestionAnswering), [])
128
+
129
+ # find_labels works regardless of the class name (it detects the framework through inheritance)
130
+ class DummyModel(FlaxBertForSequenceClassification):
131
+ pass
132
+
133
+ self.assertEqual(find_labels(DummyModel), [])
docs/transformers/tests/utils/test_generic.py ADDED
@@ -0,0 +1,463 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2019-present, the HuggingFace Inc. team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import unittest
16
+ import warnings
17
+
18
+ import numpy as np
19
+
20
+ from transformers.configuration_utils import PretrainedConfig
21
+ from transformers.modeling_outputs import BaseModelOutput
22
+ from transformers.testing_utils import require_flax, require_tf, require_torch
23
+ from transformers.utils import (
24
+ can_return_tuple,
25
+ expand_dims,
26
+ filter_out_non_signature_kwargs,
27
+ flatten_dict,
28
+ is_flax_available,
29
+ is_tf_available,
30
+ is_torch_available,
31
+ reshape,
32
+ squeeze,
33
+ to_py_obj,
34
+ transpose,
35
+ )
36
+
37
+
38
+ if is_flax_available():
39
+ import jax.numpy as jnp
40
+
41
+ if is_tf_available():
42
+ import tensorflow as tf
43
+
44
+ if is_torch_available():
45
+ import torch
46
+
47
+
48
+ class GenericTester(unittest.TestCase):
49
+ def test_flatten_dict(self):
50
+ input_dict = {
51
+ "task_specific_params": {
52
+ "summarization": {"length_penalty": 1.0, "max_length": 128, "min_length": 12, "num_beams": 4},
53
+ "summarization_cnn": {"length_penalty": 2.0, "max_length": 142, "min_length": 56, "num_beams": 4},
54
+ "summarization_xsum": {"length_penalty": 1.0, "max_length": 62, "min_length": 11, "num_beams": 6},
55
+ }
56
+ }
57
+ expected_dict = {
58
+ "task_specific_params.summarization.length_penalty": 1.0,
59
+ "task_specific_params.summarization.max_length": 128,
60
+ "task_specific_params.summarization.min_length": 12,
61
+ "task_specific_params.summarization.num_beams": 4,
62
+ "task_specific_params.summarization_cnn.length_penalty": 2.0,
63
+ "task_specific_params.summarization_cnn.max_length": 142,
64
+ "task_specific_params.summarization_cnn.min_length": 56,
65
+ "task_specific_params.summarization_cnn.num_beams": 4,
66
+ "task_specific_params.summarization_xsum.length_penalty": 1.0,
67
+ "task_specific_params.summarization_xsum.max_length": 62,
68
+ "task_specific_params.summarization_xsum.min_length": 11,
69
+ "task_specific_params.summarization_xsum.num_beams": 6,
70
+ }
71
+
72
+ self.assertEqual(flatten_dict(input_dict), expected_dict)
73
+
74
+ def test_transpose_numpy(self):
75
+ x = np.random.randn(3, 4)
76
+ self.assertTrue(np.allclose(transpose(x), x.transpose()))
77
+
78
+ x = np.random.randn(3, 4, 5)
79
+ self.assertTrue(np.allclose(transpose(x, axes=(1, 2, 0)), x.transpose((1, 2, 0))))
80
+
81
+ @require_torch
82
+ def test_transpose_torch(self):
83
+ x = np.random.randn(3, 4)
84
+ t = torch.tensor(x)
85
+ self.assertTrue(np.allclose(transpose(x), transpose(t).numpy()))
86
+
87
+ x = np.random.randn(3, 4, 5)
88
+ t = torch.tensor(x)
89
+ self.assertTrue(np.allclose(transpose(x, axes=(1, 2, 0)), transpose(t, axes=(1, 2, 0)).numpy()))
90
+
91
+ @require_tf
92
+ def test_transpose_tf(self):
93
+ x = np.random.randn(3, 4)
94
+ t = tf.constant(x)
95
+ self.assertTrue(np.allclose(transpose(x), transpose(t).numpy()))
96
+
97
+ x = np.random.randn(3, 4, 5)
98
+ t = tf.constant(x)
99
+ self.assertTrue(np.allclose(transpose(x, axes=(1, 2, 0)), transpose(t, axes=(1, 2, 0)).numpy()))
100
+
101
+ @require_flax
102
+ def test_transpose_flax(self):
103
+ x = np.random.randn(3, 4)
104
+ t = jnp.array(x)
105
+ self.assertTrue(np.allclose(transpose(x), np.asarray(transpose(t))))
106
+
107
+ x = np.random.randn(3, 4, 5)
108
+ t = jnp.array(x)
109
+ self.assertTrue(np.allclose(transpose(x, axes=(1, 2, 0)), np.asarray(transpose(t, axes=(1, 2, 0)))))
110
+
111
+ def test_reshape_numpy(self):
112
+ x = np.random.randn(3, 4)
113
+ self.assertTrue(np.allclose(reshape(x, (4, 3)), np.reshape(x, (4, 3))))
114
+
115
+ x = np.random.randn(3, 4, 5)
116
+ self.assertTrue(np.allclose(reshape(x, (12, 5)), np.reshape(x, (12, 5))))
117
+
118
+ @require_torch
119
+ def test_reshape_torch(self):
120
+ x = np.random.randn(3, 4)
121
+ t = torch.tensor(x)
122
+ self.assertTrue(np.allclose(reshape(x, (4, 3)), reshape(t, (4, 3)).numpy()))
123
+
124
+ x = np.random.randn(3, 4, 5)
125
+ t = torch.tensor(x)
126
+ self.assertTrue(np.allclose(reshape(x, (12, 5)), reshape(t, (12, 5)).numpy()))
127
+
128
+ @require_tf
129
+ def test_reshape_tf(self):
130
+ x = np.random.randn(3, 4)
131
+ t = tf.constant(x)
132
+ self.assertTrue(np.allclose(reshape(x, (4, 3)), reshape(t, (4, 3)).numpy()))
133
+
134
+ x = np.random.randn(3, 4, 5)
135
+ t = tf.constant(x)
136
+ self.assertTrue(np.allclose(reshape(x, (12, 5)), reshape(t, (12, 5)).numpy()))
137
+
138
+ @require_flax
139
+ def test_reshape_flax(self):
140
+ x = np.random.randn(3, 4)
141
+ t = jnp.array(x)
142
+ self.assertTrue(np.allclose(reshape(x, (4, 3)), np.asarray(reshape(t, (4, 3)))))
143
+
144
+ x = np.random.randn(3, 4, 5)
145
+ t = jnp.array(x)
146
+ self.assertTrue(np.allclose(reshape(x, (12, 5)), np.asarray(reshape(t, (12, 5)))))
147
+
148
+ def test_squeeze_numpy(self):
149
+ x = np.random.randn(1, 3, 4)
150
+ self.assertTrue(np.allclose(squeeze(x), np.squeeze(x)))
151
+
152
+ x = np.random.randn(1, 4, 1, 5)
153
+ self.assertTrue(np.allclose(squeeze(x, axis=2), np.squeeze(x, axis=2)))
154
+
155
+ @require_torch
156
+ def test_squeeze_torch(self):
157
+ x = np.random.randn(1, 3, 4)
158
+ t = torch.tensor(x)
159
+ self.assertTrue(np.allclose(squeeze(x), squeeze(t).numpy()))
160
+
161
+ x = np.random.randn(1, 4, 1, 5)
162
+ t = torch.tensor(x)
163
+ self.assertTrue(np.allclose(squeeze(x, axis=2), squeeze(t, axis=2).numpy()))
164
+
165
+ @require_tf
166
+ def test_squeeze_tf(self):
167
+ x = np.random.randn(1, 3, 4)
168
+ t = tf.constant(x)
169
+ self.assertTrue(np.allclose(squeeze(x), squeeze(t).numpy()))
170
+
171
+ x = np.random.randn(1, 4, 1, 5)
172
+ t = tf.constant(x)
173
+ self.assertTrue(np.allclose(squeeze(x, axis=2), squeeze(t, axis=2).numpy()))
174
+
175
+ @require_flax
176
+ def test_squeeze_flax(self):
177
+ x = np.random.randn(1, 3, 4)
178
+ t = jnp.array(x)
179
+ self.assertTrue(np.allclose(squeeze(x), np.asarray(squeeze(t))))
180
+
181
+ x = np.random.randn(1, 4, 1, 5)
182
+ t = jnp.array(x)
183
+ self.assertTrue(np.allclose(squeeze(x, axis=2), np.asarray(squeeze(t, axis=2))))
184
+
185
+ def test_expand_dims_numpy(self):
186
+ x = np.random.randn(3, 4)
187
+ self.assertTrue(np.allclose(expand_dims(x, axis=1), np.expand_dims(x, axis=1)))
188
+
189
+ @require_torch
190
+ def test_expand_dims_torch(self):
191
+ x = np.random.randn(3, 4)
192
+ t = torch.tensor(x)
193
+ self.assertTrue(np.allclose(expand_dims(x, axis=1), expand_dims(t, axis=1).numpy()))
194
+
195
+ @require_tf
196
+ def test_expand_dims_tf(self):
197
+ x = np.random.randn(3, 4)
198
+ t = tf.constant(x)
199
+ self.assertTrue(np.allclose(expand_dims(x, axis=1), expand_dims(t, axis=1).numpy()))
200
+
201
+ @require_flax
202
+ def test_expand_dims_flax(self):
203
+ x = np.random.randn(3, 4)
204
+ t = jnp.array(x)
205
+ self.assertTrue(np.allclose(expand_dims(x, axis=1), np.asarray(expand_dims(t, axis=1))))
206
+
207
+ def test_to_py_obj_native(self):
208
+ self.assertTrue(to_py_obj(1) == 1)
209
+ self.assertTrue(to_py_obj([1, 2, 3]) == [1, 2, 3])
210
+ self.assertTrue(to_py_obj([((1.0, 1.1), 1.2), (2, 3)]) == [[[1.0, 1.1], 1.2], [2, 3]])
211
+
212
+ def test_to_py_obj_numpy(self):
213
+ x1 = [[1, 2, 3], [4, 5, 6]]
214
+ t1 = np.array(x1)
215
+ self.assertTrue(to_py_obj(t1) == x1)
216
+
217
+ x2 = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
218
+ t2 = np.array(x2)
219
+ self.assertTrue(to_py_obj(t2) == x2)
220
+
221
+ self.assertTrue(to_py_obj([t1, t2]) == [x1, x2])
222
+
223
+ @require_torch
224
+ def test_to_py_obj_torch(self):
225
+ x1 = [[1, 2, 3], [4, 5, 6]]
226
+ t1 = torch.tensor(x1)
227
+ self.assertTrue(to_py_obj(t1) == x1)
228
+
229
+ x2 = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
230
+ t2 = torch.tensor(x2)
231
+ self.assertTrue(to_py_obj(t2) == x2)
232
+
233
+ self.assertTrue(to_py_obj([t1, t2]) == [x1, x2])
234
+
235
+ @require_tf
236
+ def test_to_py_obj_tf(self):
237
+ x1 = [[1, 2, 3], [4, 5, 6]]
238
+ t1 = tf.constant(x1)
239
+ self.assertTrue(to_py_obj(t1) == x1)
240
+
241
+ x2 = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
242
+ t2 = tf.constant(x2)
243
+ self.assertTrue(to_py_obj(t2) == x2)
244
+
245
+ self.assertTrue(to_py_obj([t1, t2]) == [x1, x2])
246
+
247
+ @require_flax
248
+ def test_to_py_obj_flax(self):
249
+ x1 = [[1, 2, 3], [4, 5, 6]]
250
+ t1 = jnp.array(x1)
251
+ self.assertTrue(to_py_obj(t1) == x1)
252
+
253
+ x2 = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
254
+ t2 = jnp.array(x2)
255
+ self.assertTrue(to_py_obj(t2) == x2)
256
+
257
+ self.assertTrue(to_py_obj([t1, t2]) == [x1, x2])
258
+
259
+ @require_torch
260
+ @require_tf
261
+ @require_flax
262
+ def test_to_py_obj_mixed(self):
263
+ x1 = [[1], [2]]
264
+ t1 = np.array(x1)
265
+
266
+ x2 = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
267
+ t2 = torch.tensor(x2)
268
+
269
+ x3 = [1, 2, 3]
270
+ t3 = tf.constant(x3)
271
+
272
+ x4 = [[[1.0, 2.0]]]
273
+ t4 = jnp.array(x4)
274
+
275
+ mixed = [(t1, t2), (t3, t4)]
276
+ self.assertTrue(to_py_obj(mixed) == [[x1, x2], [x3, x4]])
277
+
278
+
279
+ class ValidationDecoratorTester(unittest.TestCase):
280
+ def test_cases_no_warning(self):
281
+ with warnings.catch_warnings(record=True) as raised_warnings:
282
+ warnings.simplefilter("always")
283
+
284
+ # basic test
285
+ @filter_out_non_signature_kwargs()
286
+ def func1(a):
287
+ return a
288
+
289
+ result = func1(1)
290
+ self.assertEqual(result, 1)
291
+
292
+ # include extra kwarg
293
+ @filter_out_non_signature_kwargs(extra=["extra_arg"])
294
+ def func2(a, **kwargs):
295
+ return a, kwargs
296
+
297
+ a, kwargs = func2(1)
298
+ self.assertEqual(a, 1)
299
+ self.assertEqual(kwargs, {})
300
+
301
+ a, kwargs = func2(1, extra_arg=2)
302
+ self.assertEqual(a, 1)
303
+ self.assertEqual(kwargs, {"extra_arg": 2})
304
+
305
+ # multiple extra kwargs
306
+ @filter_out_non_signature_kwargs(extra=["extra_arg", "extra_arg2"])
307
+ def func3(a, **kwargs):
308
+ return a, kwargs
309
+
310
+ a, kwargs = func3(2)
311
+ self.assertEqual(a, 2)
312
+ self.assertEqual(kwargs, {})
313
+
314
+ a, kwargs = func3(3, extra_arg2=3)
315
+ self.assertEqual(a, 3)
316
+ self.assertEqual(kwargs, {"extra_arg2": 3})
317
+
318
+ a, kwargs = func3(1, extra_arg=2, extra_arg2=3)
319
+ self.assertEqual(a, 1)
320
+ self.assertEqual(kwargs, {"extra_arg": 2, "extra_arg2": 3})
321
+
322
+ # Check that no warnings were raised
323
+ self.assertEqual(len(raised_warnings), 0, f"Warning raised: {[w.message for w in raised_warnings]}")
324
+
325
+ def test_cases_with_warnings(self):
326
+ @filter_out_non_signature_kwargs()
327
+ def func1(a):
328
+ return a
329
+
330
+ with self.assertWarns(UserWarning):
331
+ func1(1, extra_arg=2)
332
+
333
+ @filter_out_non_signature_kwargs(extra=["extra_arg"])
334
+ def func2(a, **kwargs):
335
+ return kwargs
336
+
337
+ with self.assertWarns(UserWarning):
338
+ kwargs = func2(1, extra_arg=2, extra_arg2=3)
339
+ self.assertEqual(kwargs, {"extra_arg": 2})
340
+
341
+ @filter_out_non_signature_kwargs(extra=["extra_arg", "extra_arg2"])
342
+ def func3(a, **kwargs):
343
+ return kwargs
344
+
345
+ with self.assertWarns(UserWarning):
346
+ kwargs = func3(1, extra_arg=2, extra_arg2=3, extra_arg3=4)
347
+ self.assertEqual(kwargs, {"extra_arg": 2, "extra_arg2": 3})
348
+
349
+
350
+ @require_torch
351
+ class CanReturnTupleDecoratorTester(unittest.TestCase):
352
+ def _get_model(self, config, store_config=True, raise_in_forward=False):
353
+ # Simple model class for testing can_return_tuple decorator.
354
+ class SimpleTestModel(torch.nn.Module):
355
+ def __init__(self, config):
356
+ super().__init__()
357
+ if store_config:
358
+ self.config = config
359
+
360
+ @can_return_tuple
361
+ def forward(self, x):
362
+ if raise_in_forward:
363
+ raise ValueError("Test error")
364
+ return BaseModelOutput(
365
+ last_hidden_state=x,
366
+ hidden_states=None,
367
+ attentions=None,
368
+ )
369
+
370
+ return SimpleTestModel(config)
371
+
372
+ def test_decorator_eager(self):
373
+ """Test that the can_return_tuple decorator works with eager mode."""
374
+
375
+ # test nothing is set
376
+ config = PretrainedConfig()
377
+ model = self._get_model(config)
378
+ inputs = torch.tensor(10)
379
+ output = model(inputs)
380
+ self.assertIsInstance(
381
+ output, BaseModelOutput, "output should be a BaseModelOutput when return_dict is not set"
382
+ )
383
+
384
+ # test all explicit cases
385
+ for config_return_dict in [True, False, None]:
386
+ for return_dict in [True, False, None]:
387
+ config = PretrainedConfig(return_dict=config_return_dict)
388
+ model = self._get_model(config)
389
+ output = model(torch.tensor(10), return_dict=return_dict)
390
+
391
+ expected_type = tuple if config_return_dict is False or return_dict is False else BaseModelOutput
392
+ message = f"output should be a {expected_type.__name__} when config.use_return_dict={config_return_dict} and return_dict={return_dict}"
393
+ self.assertIsInstance(output, expected_type, message)
394
+
395
+ def test_decorator_compiled(self):
396
+ """Test that the can_return_tuple decorator works with compiled mode."""
397
+ config = PretrainedConfig()
398
+
399
+ # Output object
400
+ model = self._get_model(config)
401
+ compiled_model = torch.compile(model)
402
+ output = compiled_model(torch.tensor(10))
403
+ self.assertIsInstance(output, BaseModelOutput)
404
+
405
+ # Tuple output
406
+ model = self._get_model(config)
407
+ compiled_model = torch.compile(model)
408
+ output = compiled_model(torch.tensor(10), return_dict=False)
409
+ self.assertIsInstance(output, tuple)
410
+
411
+ def test_decorator_torch_export(self):
412
+ """Test that the can_return_tuple decorator works with torch.export."""
413
+ config = PretrainedConfig()
414
+ model = self._get_model(config)
415
+ torch.export.export(model, args=(torch.tensor(10),))
416
+
417
+ def test_decorator_torchscript(self):
418
+ """Test that the can_return_tuple decorator works with torch.jit.trace."""
419
+ config = PretrainedConfig(return_dict=False)
420
+ model = self._get_model(config)
421
+ inputs = torch.tensor(10)
422
+ traced_module = torch.jit.trace(model, inputs)
423
+ output = traced_module(inputs)
424
+ self.assertIsInstance(output, tuple)
425
+
426
+ def test_attribute_cleanup(self):
427
+ """Test that the `_is_top_level_module` attribute is removed after the forward call."""
428
+
429
+ config = PretrainedConfig(return_dict=False)
430
+ inputs = torch.tensor(10)
431
+
432
+ # working case
433
+ model = self._get_model(config)
434
+ output = model(inputs)
435
+
436
+ self.assertIsInstance(output, tuple)
437
+ for name, module in model.named_modules():
438
+ self.assertFalse(
439
+ hasattr(module, "_is_top_level_module"),
440
+ f"Module `{name}` should not have `_is_top_level_module` attribute",
441
+ )
442
+
443
+ # model without config
444
+ no_config_model = self._get_model(config, store_config=False)
445
+ output = no_config_model(inputs)
446
+
447
+ self.assertIsInstance(output, BaseModelOutput)
448
+ for name, module in no_config_model.named_modules():
449
+ self.assertFalse(
450
+ hasattr(module, "_is_top_level_module"),
451
+ f"Module `{name}` should not have `_is_top_level_module` attribute",
452
+ )
453
+
454
+ # model with raise in forward
455
+ model_with_raise = self._get_model(config, raise_in_forward=True)
456
+ with self.assertRaises(ValueError):
457
+ model_with_raise(inputs)
458
+
459
+ for name, module in model_with_raise.named_modules():
460
+ self.assertFalse(
461
+ hasattr(module, "_is_top_level_module"),
462
+ f"Module `{name}` should not have `_is_top_level_module` attribute",
463
+ )
docs/transformers/tests/utils/test_hf_argparser.py ADDED
@@ -0,0 +1,482 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import argparse
16
+ import json
17
+ import os
18
+ import sys
19
+ import tempfile
20
+ import unittest
21
+ from argparse import Namespace
22
+ from dataclasses import dataclass, field
23
+ from enum import Enum
24
+ from pathlib import Path
25
+ from typing import List, Literal, Optional, Union, get_args, get_origin
26
+
27
+ import yaml
28
+
29
+ from transformers import HfArgumentParser, TrainingArguments
30
+ from transformers.hf_argparser import make_choice_type_function, string_to_bool
31
+ from transformers.testing_utils import require_torch
32
+
33
+
34
+ # Since Python 3.10, we can use the builtin `|` operator for Union types
35
+ # See PEP 604: https://peps.python.org/pep-0604
36
+ is_python_no_less_than_3_10 = sys.version_info >= (3, 10)
37
+
38
+
39
+ def list_field(default=None, metadata=None):
40
+ return field(default_factory=lambda: default, metadata=metadata)
41
+
42
+
43
+ @dataclass
44
+ class BasicExample:
45
+ foo: int
46
+ bar: float
47
+ baz: str
48
+ flag: bool
49
+
50
+
51
+ @dataclass
52
+ class WithDefaultExample:
53
+ foo: int = 42
54
+ baz: str = field(default="toto", metadata={"help": "help message"})
55
+
56
+
57
+ @dataclass
58
+ class WithDefaultBoolExample:
59
+ foo: bool = False
60
+ baz: bool = True
61
+ opt: Optional[bool] = None
62
+
63
+
64
+ class BasicEnum(Enum):
65
+ titi = "titi"
66
+ toto = "toto"
67
+
68
+
69
+ class MixedTypeEnum(Enum):
70
+ titi = "titi"
71
+ toto = "toto"
72
+ fourtytwo = 42
73
+
74
+
75
+ @dataclass
76
+ class EnumExample:
77
+ foo: BasicEnum = "toto"
78
+
79
+ def __post_init__(self):
80
+ self.foo = BasicEnum(self.foo)
81
+
82
+
83
+ @dataclass
84
+ class MixedTypeEnumExample:
85
+ foo: MixedTypeEnum = "toto"
86
+
87
+ def __post_init__(self):
88
+ self.foo = MixedTypeEnum(self.foo)
89
+
90
+
91
+ @dataclass
92
+ class OptionalExample:
93
+ foo: Optional[int] = None
94
+ bar: Optional[float] = field(default=None, metadata={"help": "help message"})
95
+ baz: Optional[str] = None
96
+ ces: Optional[list[str]] = list_field(default=[])
97
+ des: Optional[list[int]] = list_field(default=[])
98
+
99
+
100
+ @dataclass
101
+ class ListExample:
102
+ foo_int: list[int] = list_field(default=[])
103
+ bar_int: list[int] = list_field(default=[1, 2, 3])
104
+ foo_str: list[str] = list_field(default=["Hallo", "Bonjour", "Hello"])
105
+ foo_float: list[float] = list_field(default=[0.1, 0.2, 0.3])
106
+
107
+
108
+ @dataclass
109
+ class RequiredExample:
110
+ required_list: list[int] = field()
111
+ required_str: str = field()
112
+ required_enum: BasicEnum = field()
113
+
114
+ def __post_init__(self):
115
+ self.required_enum = BasicEnum(self.required_enum)
116
+
117
+
118
+ @dataclass
119
+ class StringLiteralAnnotationExample:
120
+ foo: int
121
+ required_enum: "BasicEnum" = field()
122
+ opt: "Optional[bool]" = None
123
+ baz: "str" = field(default="toto", metadata={"help": "help message"})
124
+ foo_str: "List[str]" = list_field(default=["Hallo", "Bonjour", "Hello"])
125
+
126
+
127
+ if is_python_no_less_than_3_10:
128
+
129
+ @dataclass
130
+ class WithDefaultBoolExamplePep604:
131
+ foo: bool = False
132
+ baz: bool = True
133
+ opt: bool | None = None
134
+
135
+ @dataclass
136
+ class OptionalExamplePep604:
137
+ foo: int | None = None
138
+ bar: float | None = field(default=None, metadata={"help": "help message"})
139
+ baz: str | None = None
140
+ ces: list[str] | None = list_field(default=[])
141
+ des: list[int] | None = list_field(default=[])
142
+
143
+
144
+ class HfArgumentParserTest(unittest.TestCase):
145
+ def argparsersEqual(self, a: argparse.ArgumentParser, b: argparse.ArgumentParser):
146
+ """
147
+ Small helper to check pseudo-equality of parsed arguments on `ArgumentParser` instances.
148
+ """
149
+ self.assertEqual(len(a._actions), len(b._actions))
150
+ for x, y in zip(a._actions, b._actions):
151
+ xx = {k: v for k, v in vars(x).items() if k != "container"}
152
+ yy = {k: v for k, v in vars(y).items() if k != "container"}
153
+
154
+ # Choices with mixed type have custom function as "type"
155
+ # So we need to compare results directly for equality
156
+ if xx.get("choices", None) and yy.get("choices", None):
157
+ for expected_choice in yy["choices"] + xx["choices"]:
158
+ self.assertEqual(xx["type"](expected_choice), yy["type"](expected_choice))
159
+ del xx["type"], yy["type"]
160
+
161
+ self.assertEqual(xx, yy)
162
+
163
+ def test_basic(self):
164
+ parser = HfArgumentParser(BasicExample)
165
+
166
+ expected = argparse.ArgumentParser()
167
+ expected.add_argument("--foo", type=int, required=True)
168
+ expected.add_argument("--bar", type=float, required=True)
169
+ expected.add_argument("--baz", type=str, required=True)
170
+ expected.add_argument("--flag", type=string_to_bool, default=False, const=True, nargs="?")
171
+ self.argparsersEqual(parser, expected)
172
+
173
+ args = ["--foo", "1", "--baz", "quux", "--bar", "0.5"]
174
+ (example,) = parser.parse_args_into_dataclasses(args, look_for_args_file=False)
175
+ self.assertFalse(example.flag)
176
+
177
+ def test_with_default(self):
178
+ parser = HfArgumentParser(WithDefaultExample)
179
+
180
+ expected = argparse.ArgumentParser()
181
+ expected.add_argument("--foo", default=42, type=int)
182
+ expected.add_argument("--baz", default="toto", type=str, help="help message")
183
+ self.argparsersEqual(parser, expected)
184
+
185
+ def test_with_default_bool(self):
186
+ expected = argparse.ArgumentParser()
187
+ expected.add_argument("--foo", type=string_to_bool, default=False, const=True, nargs="?")
188
+ expected.add_argument("--baz", type=string_to_bool, default=True, const=True, nargs="?")
189
+ # A boolean no_* argument always has to come after its "default: True" regular counter-part
190
+ # and its default must be set to False
191
+ expected.add_argument("--no_baz", "--no-baz", action="store_false", default=False, dest="baz")
192
+ expected.add_argument("--opt", type=string_to_bool, default=None)
193
+
194
+ dataclass_types = [WithDefaultBoolExample]
195
+ if is_python_no_less_than_3_10:
196
+ dataclass_types.append(WithDefaultBoolExamplePep604)
197
+
198
+ for dataclass_type in dataclass_types:
199
+ parser = HfArgumentParser(dataclass_type)
200
+ self.argparsersEqual(parser, expected)
201
+
202
+ args = parser.parse_args([])
203
+ self.assertEqual(args, Namespace(foo=False, baz=True, opt=None))
204
+
205
+ args = parser.parse_args(["--foo", "--no_baz"])
206
+ self.assertEqual(args, Namespace(foo=True, baz=False, opt=None))
207
+
208
+ args = parser.parse_args(["--foo", "--no-baz"])
209
+ self.assertEqual(args, Namespace(foo=True, baz=False, opt=None))
210
+
211
+ args = parser.parse_args(["--foo", "--baz"])
212
+ self.assertEqual(args, Namespace(foo=True, baz=True, opt=None))
213
+
214
+ args = parser.parse_args(["--foo", "True", "--baz", "True", "--opt", "True"])
215
+ self.assertEqual(args, Namespace(foo=True, baz=True, opt=True))
216
+
217
+ args = parser.parse_args(["--foo", "False", "--baz", "False", "--opt", "False"])
218
+ self.assertEqual(args, Namespace(foo=False, baz=False, opt=False))
219
+
220
+ def test_with_enum(self):
221
+ parser = HfArgumentParser(MixedTypeEnumExample)
222
+
223
+ expected = argparse.ArgumentParser()
224
+ expected.add_argument(
225
+ "--foo",
226
+ default="toto",
227
+ choices=["titi", "toto", 42],
228
+ type=make_choice_type_function(["titi", "toto", 42]),
229
+ )
230
+ self.argparsersEqual(parser, expected)
231
+
232
+ args = parser.parse_args([])
233
+ self.assertEqual(args.foo, "toto")
234
+ enum_ex = parser.parse_args_into_dataclasses([])[0]
235
+ self.assertEqual(enum_ex.foo, MixedTypeEnum.toto)
236
+
237
+ args = parser.parse_args(["--foo", "titi"])
238
+ self.assertEqual(args.foo, "titi")
239
+ enum_ex = parser.parse_args_into_dataclasses(["--foo", "titi"])[0]
240
+ self.assertEqual(enum_ex.foo, MixedTypeEnum.titi)
241
+
242
+ args = parser.parse_args(["--foo", "42"])
243
+ self.assertEqual(args.foo, 42)
244
+ enum_ex = parser.parse_args_into_dataclasses(["--foo", "42"])[0]
245
+ self.assertEqual(enum_ex.foo, MixedTypeEnum.fourtytwo)
246
+
247
+ def test_with_literal(self):
248
+ @dataclass
249
+ class LiteralExample:
250
+ foo: Literal["titi", "toto", 42] = "toto"
251
+
252
+ parser = HfArgumentParser(LiteralExample)
253
+
254
+ expected = argparse.ArgumentParser()
255
+ expected.add_argument(
256
+ "--foo",
257
+ default="toto",
258
+ choices=("titi", "toto", 42),
259
+ type=make_choice_type_function(["titi", "toto", 42]),
260
+ )
261
+ self.argparsersEqual(parser, expected)
262
+
263
+ args = parser.parse_args([])
264
+ self.assertEqual(args.foo, "toto")
265
+
266
+ args = parser.parse_args(["--foo", "titi"])
267
+ self.assertEqual(args.foo, "titi")
268
+
269
+ args = parser.parse_args(["--foo", "42"])
270
+ self.assertEqual(args.foo, 42)
271
+
272
+ def test_with_list(self):
273
+ parser = HfArgumentParser(ListExample)
274
+
275
+ expected = argparse.ArgumentParser()
276
+ expected.add_argument("--foo_int", "--foo-int", nargs="+", default=[], type=int)
277
+ expected.add_argument("--bar_int", "--bar-int", nargs="+", default=[1, 2, 3], type=int)
278
+ expected.add_argument("--foo_str", "--foo-str", nargs="+", default=["Hallo", "Bonjour", "Hello"], type=str)
279
+ expected.add_argument("--foo_float", "--foo-float", nargs="+", default=[0.1, 0.2, 0.3], type=float)
280
+
281
+ self.argparsersEqual(parser, expected)
282
+
283
+ args = parser.parse_args([])
284
+ self.assertEqual(
285
+ args,
286
+ Namespace(foo_int=[], bar_int=[1, 2, 3], foo_str=["Hallo", "Bonjour", "Hello"], foo_float=[0.1, 0.2, 0.3]),
287
+ )
288
+
289
+ args = parser.parse_args("--foo_int 1 --bar_int 2 3 --foo_str a b c --foo_float 0.1 0.7".split())
290
+ self.assertEqual(args, Namespace(foo_int=[1], bar_int=[2, 3], foo_str=["a", "b", "c"], foo_float=[0.1, 0.7]))
291
+
292
+ args = parser.parse_args("--foo-int 1 --bar-int 2 3 --foo-str a b c --foo-float 0.1 0.7".split())
293
+ self.assertEqual(args, Namespace(foo_int=[1], bar_int=[2, 3], foo_str=["a", "b", "c"], foo_float=[0.1, 0.7]))
294
+
295
+ def test_with_optional(self):
296
+ expected = argparse.ArgumentParser()
297
+ expected.add_argument("--foo", default=None, type=int)
298
+ expected.add_argument("--bar", default=None, type=float, help="help message")
299
+ expected.add_argument("--baz", default=None, type=str)
300
+ expected.add_argument("--ces", nargs="+", default=[], type=str)
301
+ expected.add_argument("--des", nargs="+", default=[], type=int)
302
+
303
+ dataclass_types = [OptionalExample]
304
+ if is_python_no_less_than_3_10:
305
+ dataclass_types.append(OptionalExamplePep604)
306
+
307
+ for dataclass_type in dataclass_types:
308
+ parser = HfArgumentParser(dataclass_type)
309
+
310
+ self.argparsersEqual(parser, expected)
311
+
312
+ args = parser.parse_args([])
313
+ self.assertEqual(args, Namespace(foo=None, bar=None, baz=None, ces=[], des=[]))
314
+
315
+ args = parser.parse_args("--foo 12 --bar 3.14 --baz 42 --ces a b c --des 1 2 3".split())
316
+ self.assertEqual(args, Namespace(foo=12, bar=3.14, baz="42", ces=["a", "b", "c"], des=[1, 2, 3]))
317
+
318
+ def test_with_required(self):
319
+ parser = HfArgumentParser(RequiredExample)
320
+
321
+ expected = argparse.ArgumentParser()
322
+ expected.add_argument("--required_list", "--required-list", nargs="+", type=int, required=True)
323
+ expected.add_argument("--required_str", "--required-str", type=str, required=True)
324
+ expected.add_argument(
325
+ "--required_enum",
326
+ "--required-enum",
327
+ type=make_choice_type_function(["titi", "toto"]),
328
+ choices=["titi", "toto"],
329
+ required=True,
330
+ )
331
+ self.argparsersEqual(parser, expected)
332
+
333
+ def test_with_string_literal_annotation(self):
334
+ parser = HfArgumentParser(StringLiteralAnnotationExample)
335
+
336
+ expected = argparse.ArgumentParser()
337
+ expected.add_argument("--foo", type=int, required=True)
338
+ expected.add_argument(
339
+ "--required_enum",
340
+ "--required-enum",
341
+ type=make_choice_type_function(["titi", "toto"]),
342
+ choices=["titi", "toto"],
343
+ required=True,
344
+ )
345
+ expected.add_argument("--opt", type=string_to_bool, default=None)
346
+ expected.add_argument("--baz", default="toto", type=str, help="help message")
347
+ expected.add_argument("--foo_str", "--foo-str", nargs="+", default=["Hallo", "Bonjour", "Hello"], type=str)
348
+ self.argparsersEqual(parser, expected)
349
+
350
+ def test_parse_dict(self):
351
+ parser = HfArgumentParser(BasicExample)
352
+
353
+ args_dict = {
354
+ "foo": 12,
355
+ "bar": 3.14,
356
+ "baz": "42",
357
+ "flag": True,
358
+ }
359
+
360
+ parsed_args = parser.parse_dict(args_dict)[0]
361
+ args = BasicExample(**args_dict)
362
+ self.assertEqual(parsed_args, args)
363
+
364
+ def test_parse_dict_extra_key(self):
365
+ parser = HfArgumentParser(BasicExample)
366
+
367
+ args_dict = {
368
+ "foo": 12,
369
+ "bar": 3.14,
370
+ "baz": "42",
371
+ "flag": True,
372
+ "extra": 42,
373
+ }
374
+
375
+ self.assertRaises(ValueError, parser.parse_dict, args_dict, allow_extra_keys=False)
376
+
377
+ def test_parse_json(self):
378
+ parser = HfArgumentParser(BasicExample)
379
+
380
+ args_dict_for_json = {
381
+ "foo": 12,
382
+ "bar": 3.14,
383
+ "baz": "42",
384
+ "flag": True,
385
+ }
386
+ with tempfile.TemporaryDirectory() as tmp_dir:
387
+ temp_local_path = os.path.join(tmp_dir, "temp_json")
388
+ os.mkdir(temp_local_path)
389
+ with open(temp_local_path + ".json", "w+") as f:
390
+ json.dump(args_dict_for_json, f)
391
+ parsed_args = parser.parse_json_file(Path(temp_local_path + ".json"))[0]
392
+
393
+ args = BasicExample(**args_dict_for_json)
394
+ self.assertEqual(parsed_args, args)
395
+
396
+ def test_parse_yaml(self):
397
+ parser = HfArgumentParser(BasicExample)
398
+
399
+ args_dict_for_yaml = {
400
+ "foo": 12,
401
+ "bar": 3.14,
402
+ "baz": "42",
403
+ "flag": True,
404
+ }
405
+ with tempfile.TemporaryDirectory() as tmp_dir:
406
+ temp_local_path = os.path.join(tmp_dir, "temp_yaml")
407
+ os.mkdir(temp_local_path)
408
+ with open(temp_local_path + ".yaml", "w+") as f:
409
+ yaml.dump(args_dict_for_yaml, f)
410
+ parsed_args = parser.parse_yaml_file(Path(temp_local_path + ".yaml"))[0]
411
+ args = BasicExample(**args_dict_for_yaml)
412
+ self.assertEqual(parsed_args, args)
413
+
414
+ def test_z_integration_training_args(self):
415
+ # make sure that this test executes last in the test suite
416
+ parser = HfArgumentParser(TrainingArguments)
417
+ self.assertIsNotNone(parser)
418
+
419
+ def test_valid_dict_annotation(self):
420
+ """
421
+ Tests to make sure that `dict` based annotations
422
+ are correctly made in the `TrainingArguments`.
423
+
424
+ If this fails, a type annotation change is
425
+ needed on a new input
426
+ """
427
+ base_list = TrainingArguments._VALID_DICT_FIELDS.copy()
428
+ args = TrainingArguments
429
+
430
+ # First find any annotations that contain `dict`
431
+ fields = args.__dataclass_fields__
432
+
433
+ raw_dict_fields = []
434
+ optional_dict_fields = []
435
+
436
+ for field in fields.values():
437
+ # First verify raw dict
438
+ if field.type in (dict, dict):
439
+ raw_dict_fields.append(field)
440
+ # Next check for `Union` or `Optional`
441
+ elif get_origin(field.type) == Union:
442
+ if any(arg in (dict, dict) for arg in get_args(field.type)):
443
+ optional_dict_fields.append(field)
444
+
445
+ # First check: anything in `raw_dict_fields` is very bad
446
+ self.assertEqual(
447
+ len(raw_dict_fields),
448
+ 0,
449
+ "Found invalid raw `dict` types in the `TrainingArgument` typings. "
450
+ "This leads to issues with the CLI. Please turn this into `typing.Optional[dict,str]`",
451
+ )
452
+
453
+ # Next check raw annotations
454
+ for field in optional_dict_fields:
455
+ args = get_args(field.type)
456
+ # These should be returned as `dict`, `str`, ...
457
+ # we only care about the first two
458
+ self.assertIn(args[0], (dict, dict))
459
+ self.assertEqual(
460
+ str(args[1]),
461
+ "<class 'str'>",
462
+ f"Expected field `{field.name}` to have a type signature of at least `typing.Union[dict,str,...]` for CLI compatibility, "
463
+ "but `str` not found. Please fix this.",
464
+ )
465
+
466
+ # Second check: anything in `optional_dict_fields` is bad if it's not in `base_list`
467
+ for field in optional_dict_fields:
468
+ self.assertIn(
469
+ field.name,
470
+ base_list,
471
+ f"Optional dict field `{field.name}` is not in the base list of valid fields. Please add it to `TrainingArguments._VALID_DICT_FIELDS`",
472
+ )
473
+
474
+ @require_torch
475
+ def test_valid_dict_input_parsing(self):
476
+ with tempfile.TemporaryDirectory() as tmp_dir:
477
+ args = TrainingArguments(
478
+ output_dir=tmp_dir,
479
+ accelerator_config='{"split_batches": "True", "gradient_accumulation_kwargs": {"num_steps": 2}}',
480
+ )
481
+ self.assertEqual(args.accelerator_config.split_batches, True)
482
+ self.assertEqual(args.accelerator_config.gradient_accumulation_kwargs["num_steps"], 2)
docs/transformers/tests/utils/test_hub_utils.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import json
15
+ import os
16
+ import tempfile
17
+ import unittest
18
+ import unittest.mock as mock
19
+ from pathlib import Path
20
+
21
+ from huggingface_hub import hf_hub_download
22
+ from requests.exceptions import HTTPError
23
+
24
+ from transformers.utils import (
25
+ CONFIG_NAME,
26
+ FLAX_WEIGHTS_NAME,
27
+ TF2_WEIGHTS_NAME,
28
+ TRANSFORMERS_CACHE,
29
+ WEIGHTS_NAME,
30
+ cached_file,
31
+ has_file,
32
+ )
33
+
34
+
35
+ RANDOM_BERT = "hf-internal-testing/tiny-random-bert"
36
+ TINY_BERT_PT_ONLY = "hf-internal-testing/tiny-bert-pt-only"
37
+ CACHE_DIR = os.path.join(TRANSFORMERS_CACHE, "models--hf-internal-testing--tiny-random-bert")
38
+ FULL_COMMIT_HASH = "9b8c223d42b2188cb49d29af482996f9d0f3e5a6"
39
+
40
+ GATED_REPO = "hf-internal-testing/dummy-gated-model"
41
+ README_FILE = "README.md"
42
+
43
+
44
+ class GetFromCacheTests(unittest.TestCase):
45
+ def test_cached_file(self):
46
+ archive_file = cached_file(RANDOM_BERT, CONFIG_NAME)
47
+ # Should have downloaded the file in here
48
+ self.assertTrue(os.path.isdir(CACHE_DIR))
49
+ # Cache should contain at least those three subfolders:
50
+ for subfolder in ["blobs", "refs", "snapshots"]:
51
+ self.assertTrue(os.path.isdir(os.path.join(CACHE_DIR, subfolder)))
52
+ with open(os.path.join(CACHE_DIR, "refs", "main")) as f:
53
+ main_commit = f.read()
54
+ self.assertEqual(archive_file, os.path.join(CACHE_DIR, "snapshots", main_commit, CONFIG_NAME))
55
+ self.assertTrue(os.path.isfile(archive_file))
56
+
57
+ # File is cached at the same place the second time.
58
+ new_archive_file = cached_file(RANDOM_BERT, CONFIG_NAME)
59
+ self.assertEqual(archive_file, new_archive_file)
60
+
61
+ # Using a specific revision to test the full commit hash.
62
+ archive_file = cached_file(RANDOM_BERT, CONFIG_NAME, revision="9b8c223")
63
+ self.assertEqual(archive_file, os.path.join(CACHE_DIR, "snapshots", FULL_COMMIT_HASH, CONFIG_NAME))
64
+
65
+ def test_cached_file_errors(self):
66
+ with self.assertRaisesRegex(EnvironmentError, "is not a valid model identifier"):
67
+ _ = cached_file("tiny-random-bert", CONFIG_NAME)
68
+
69
+ with self.assertRaisesRegex(EnvironmentError, "is not a valid git identifier"):
70
+ _ = cached_file(RANDOM_BERT, CONFIG_NAME, revision="aaaa")
71
+
72
+ with self.assertRaisesRegex(EnvironmentError, "does not appear to have a file named"):
73
+ _ = cached_file(RANDOM_BERT, "conf")
74
+
75
+ def test_non_existence_is_cached(self):
76
+ with self.assertRaisesRegex(EnvironmentError, "does not appear to have a file named"):
77
+ _ = cached_file(RANDOM_BERT, "conf")
78
+
79
+ with open(os.path.join(CACHE_DIR, "refs", "main")) as f:
80
+ main_commit = f.read()
81
+ self.assertTrue(os.path.isfile(os.path.join(CACHE_DIR, ".no_exist", main_commit, "conf")))
82
+
83
+ path = cached_file(RANDOM_BERT, "conf", _raise_exceptions_for_missing_entries=False)
84
+ self.assertIsNone(path)
85
+
86
+ path = cached_file(RANDOM_BERT, "conf", local_files_only=True, _raise_exceptions_for_missing_entries=False)
87
+ self.assertIsNone(path)
88
+
89
+ # Under the mock environment, hf_hub_download will always raise an HTTPError
90
+ with mock.patch("transformers.utils.hub.hf_hub_download", side_effect=HTTPError) as mock_head:
91
+ path = cached_file(RANDOM_BERT, "conf", _raise_exceptions_for_connection_errors=False)
92
+ self.assertIsNone(path)
93
+ # This check we did call the fake head request
94
+ mock_head.assert_called()
95
+
96
+ def test_has_file(self):
97
+ self.assertTrue(has_file(TINY_BERT_PT_ONLY, WEIGHTS_NAME))
98
+ self.assertFalse(has_file(TINY_BERT_PT_ONLY, TF2_WEIGHTS_NAME))
99
+ self.assertFalse(has_file(TINY_BERT_PT_ONLY, FLAX_WEIGHTS_NAME))
100
+
101
+ def test_has_file_in_cache(self):
102
+ with tempfile.TemporaryDirectory() as tmp_dir:
103
+ # Empty cache dir + offline mode => return False
104
+ assert not has_file(TINY_BERT_PT_ONLY, WEIGHTS_NAME, local_files_only=True, cache_dir=tmp_dir)
105
+
106
+ # Populate cache dir
107
+ hf_hub_download(TINY_BERT_PT_ONLY, WEIGHTS_NAME, cache_dir=tmp_dir)
108
+
109
+ # Cache dir + offline mode => return True
110
+ assert has_file(TINY_BERT_PT_ONLY, WEIGHTS_NAME, local_files_only=True, cache_dir=tmp_dir)
111
+
112
+ def test_get_file_from_repo_distant(self):
113
+ # should return None if the file does not exist
114
+ self.assertIsNone(
115
+ cached_file(
116
+ "google-bert/bert-base-cased",
117
+ "ahah.txt",
118
+ _raise_exceptions_for_gated_repo=False,
119
+ _raise_exceptions_for_missing_entries=False,
120
+ _raise_exceptions_for_connection_errors=False,
121
+ )
122
+ )
123
+
124
+ # The function raises if the repository does not exist.
125
+ with self.assertRaisesRegex(EnvironmentError, "is not a valid model identifier"):
126
+ cached_file(
127
+ "bert-base-case",
128
+ CONFIG_NAME,
129
+ _raise_exceptions_for_gated_repo=False,
130
+ _raise_exceptions_for_missing_entries=False,
131
+ _raise_exceptions_for_connection_errors=False,
132
+ )
133
+
134
+ # The function raises if the revision does not exist.
135
+ with self.assertRaisesRegex(EnvironmentError, "is not a valid git identifier"):
136
+ cached_file(
137
+ "google-bert/bert-base-cased",
138
+ CONFIG_NAME,
139
+ revision="ahaha",
140
+ _raise_exceptions_for_gated_repo=False,
141
+ _raise_exceptions_for_missing_entries=False,
142
+ _raise_exceptions_for_connection_errors=False,
143
+ )
144
+
145
+ resolved_file = cached_file(
146
+ "google-bert/bert-base-cased",
147
+ CONFIG_NAME,
148
+ _raise_exceptions_for_gated_repo=False,
149
+ _raise_exceptions_for_missing_entries=False,
150
+ _raise_exceptions_for_connection_errors=False,
151
+ )
152
+ # The name is the cached name which is not very easy to test, so instead we load the content.
153
+ config = json.loads(open(resolved_file).read())
154
+ self.assertEqual(config["hidden_size"], 768)
155
+
156
+ def test_get_file_from_repo_local(self):
157
+ with tempfile.TemporaryDirectory() as tmp_dir:
158
+ filename = Path(tmp_dir) / "a.txt"
159
+ filename.touch()
160
+ self.assertEqual(
161
+ cached_file(
162
+ tmp_dir,
163
+ "a.txt",
164
+ _raise_exceptions_for_gated_repo=False,
165
+ _raise_exceptions_for_missing_entries=False,
166
+ _raise_exceptions_for_connection_errors=False,
167
+ ),
168
+ str(filename),
169
+ )
170
+
171
+ self.assertIsNone(
172
+ cached_file(
173
+ tmp_dir,
174
+ "b.txt",
175
+ _raise_exceptions_for_gated_repo=False,
176
+ _raise_exceptions_for_missing_entries=False,
177
+ _raise_exceptions_for_connection_errors=False,
178
+ )
179
+ )
180
+
181
+ def test_get_file_gated_repo(self):
182
+ """Test download file from a gated repo fails with correct message when not authenticated."""
183
+ with self.assertRaisesRegex(EnvironmentError, "You are trying to access a gated repo."):
184
+ # All files except README.md are protected on a gated repo.
185
+ cached_file(GATED_REPO, "gated_file.txt", token=False)
186
+
187
+ def test_has_file_gated_repo(self):
188
+ """Test check file existence from a gated repo fails with correct message when not authenticated."""
189
+ with self.assertRaisesRegex(EnvironmentError, "is a gated repository"):
190
+ # All files except README.md are protected on a gated repo.
191
+ has_file(GATED_REPO, "gated_file.txt", token=False)
192
+
193
+ def test_cached_files_exception_raised(self):
194
+ """Test that unhadled exceptions, e.g. ModuleNotFoundError, is properly re-raised by cached_files when hf_hub_download fails."""
195
+ with mock.patch(
196
+ "transformers.utils.hub.hf_hub_download", side_effect=ModuleNotFoundError("No module named 'MockModule'")
197
+ ):
198
+ with self.assertRaises(ModuleNotFoundError):
199
+ # The error should be re-raised by cached_files, not caught in the exception handling block
200
+ cached_file(RANDOM_BERT, "nonexistent.json")
docs/transformers/tests/utils/test_image_processing_utils.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 HuggingFace Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import sys
16
+ import tempfile
17
+ import unittest
18
+ import unittest.mock as mock
19
+ from pathlib import Path
20
+
21
+ from huggingface_hub import HfFolder
22
+ from requests.exceptions import HTTPError
23
+
24
+ from transformers import AutoImageProcessor, ViTImageProcessor
25
+ from transformers.image_processing_utils import get_size_dict
26
+ from transformers.testing_utils import TOKEN, TemporaryHubRepo, get_tests_dir, is_staging_test
27
+
28
+
29
+ sys.path.append(str(Path(__file__).parent.parent.parent / "utils"))
30
+
31
+ from test_module.custom_image_processing import CustomImageProcessor # noqa E402
32
+
33
+
34
+ SAMPLE_IMAGE_PROCESSING_CONFIG_DIR = get_tests_dir("fixtures")
35
+
36
+
37
+ class ImageProcessorUtilTester(unittest.TestCase):
38
+ def test_cached_files_are_used_when_internet_is_down(self):
39
+ # A mock response for an HTTP head request to emulate server down
40
+ response_mock = mock.Mock()
41
+ response_mock.status_code = 500
42
+ response_mock.headers = {}
43
+ response_mock.raise_for_status.side_effect = HTTPError
44
+ response_mock.json.return_value = {}
45
+
46
+ # Download this model to make sure it's in the cache.
47
+ _ = ViTImageProcessor.from_pretrained("hf-internal-testing/tiny-random-vit")
48
+ # Under the mock environment we get a 500 error when trying to reach the model.
49
+ with mock.patch("requests.Session.request", return_value=response_mock) as mock_head:
50
+ _ = ViTImageProcessor.from_pretrained("hf-internal-testing/tiny-random-vit")
51
+ # This check we did call the fake head request
52
+ mock_head.assert_called()
53
+
54
+ def test_image_processor_from_pretrained_subfolder(self):
55
+ with self.assertRaises(OSError):
56
+ # config is in subfolder, the following should not work without specifying the subfolder
57
+ _ = AutoImageProcessor.from_pretrained("hf-internal-testing/stable-diffusion-all-variants")
58
+
59
+ config = AutoImageProcessor.from_pretrained(
60
+ "hf-internal-testing/stable-diffusion-all-variants", subfolder="feature_extractor"
61
+ )
62
+
63
+ self.assertIsNotNone(config)
64
+
65
+
66
+ @is_staging_test
67
+ class ImageProcessorPushToHubTester(unittest.TestCase):
68
+ @classmethod
69
+ def setUpClass(cls):
70
+ cls._token = TOKEN
71
+ HfFolder.save_token(TOKEN)
72
+
73
+ def test_push_to_hub(self):
74
+ with TemporaryHubRepo(token=self._token) as tmp_repo:
75
+ image_processor = ViTImageProcessor.from_pretrained(SAMPLE_IMAGE_PROCESSING_CONFIG_DIR)
76
+ image_processor.push_to_hub(tmp_repo.repo_id, token=self._token)
77
+
78
+ new_image_processor = ViTImageProcessor.from_pretrained(tmp_repo.repo_id)
79
+ for k, v in image_processor.__dict__.items():
80
+ self.assertEqual(v, getattr(new_image_processor, k))
81
+
82
+ def test_push_to_hub_via_save_pretrained(self):
83
+ with TemporaryHubRepo(token=self._token) as tmp_repo:
84
+ image_processor = ViTImageProcessor.from_pretrained(SAMPLE_IMAGE_PROCESSING_CONFIG_DIR)
85
+ # Push to hub via save_pretrained
86
+ with tempfile.TemporaryDirectory() as tmp_dir:
87
+ image_processor.save_pretrained(tmp_dir, repo_id=tmp_repo.repo_id, push_to_hub=True, token=self._token)
88
+
89
+ new_image_processor = ViTImageProcessor.from_pretrained(tmp_repo.repo_id)
90
+ for k, v in image_processor.__dict__.items():
91
+ self.assertEqual(v, getattr(new_image_processor, k))
92
+
93
+ def test_push_to_hub_in_organization(self):
94
+ with TemporaryHubRepo(namespace="valid_org", token=self._token) as tmp_repo:
95
+ image_processor = ViTImageProcessor.from_pretrained(SAMPLE_IMAGE_PROCESSING_CONFIG_DIR)
96
+ image_processor.push_to_hub(tmp_repo.repo_id, token=self._token)
97
+
98
+ new_image_processor = ViTImageProcessor.from_pretrained(tmp_repo.repo_id)
99
+ for k, v in image_processor.__dict__.items():
100
+ self.assertEqual(v, getattr(new_image_processor, k))
101
+
102
+ def test_push_to_hub_in_organization_via_save_pretrained(self):
103
+ with TemporaryHubRepo(namespace="valid_org", token=self._token) as tmp_repo:
104
+ image_processor = ViTImageProcessor.from_pretrained(SAMPLE_IMAGE_PROCESSING_CONFIG_DIR)
105
+ # Push to hub via save_pretrained
106
+ with tempfile.TemporaryDirectory() as tmp_dir:
107
+ image_processor.save_pretrained(tmp_dir, repo_id=tmp_repo.repo_id, push_to_hub=True, token=self._token)
108
+
109
+ new_image_processor = ViTImageProcessor.from_pretrained(tmp_repo.repo_id)
110
+ for k, v in image_processor.__dict__.items():
111
+ self.assertEqual(v, getattr(new_image_processor, k))
112
+
113
+ def test_push_to_hub_dynamic_image_processor(self):
114
+ with TemporaryHubRepo(token=self._token) as tmp_repo:
115
+ CustomImageProcessor.register_for_auto_class()
116
+ image_processor = CustomImageProcessor.from_pretrained(SAMPLE_IMAGE_PROCESSING_CONFIG_DIR)
117
+
118
+ image_processor.push_to_hub(tmp_repo.repo_id, token=self._token)
119
+
120
+ # This has added the proper auto_map field to the config
121
+ self.assertDictEqual(
122
+ image_processor.auto_map,
123
+ {"AutoImageProcessor": "custom_image_processing.CustomImageProcessor"},
124
+ )
125
+
126
+ new_image_processor = AutoImageProcessor.from_pretrained(tmp_repo.repo_id, trust_remote_code=True)
127
+ # Can't make an isinstance check because the new_image_processor is from the CustomImageProcessor class of a dynamic module
128
+ self.assertEqual(new_image_processor.__class__.__name__, "CustomImageProcessor")
129
+
130
+
131
+ class ImageProcessingUtilsTester(unittest.TestCase):
132
+ def test_get_size_dict(self):
133
+ # Test a dict with the wrong keys raises an error
134
+ inputs = {"wrong_key": 224}
135
+ with self.assertRaises(ValueError):
136
+ get_size_dict(inputs)
137
+
138
+ inputs = {"height": 224}
139
+ with self.assertRaises(ValueError):
140
+ get_size_dict(inputs)
141
+
142
+ inputs = {"width": 224, "shortest_edge": 224}
143
+ with self.assertRaises(ValueError):
144
+ get_size_dict(inputs)
145
+
146
+ # Test a dict with the correct keys is returned as is
147
+ inputs = {"height": 224, "width": 224}
148
+ outputs = get_size_dict(inputs)
149
+ self.assertEqual(outputs, inputs)
150
+
151
+ inputs = {"shortest_edge": 224}
152
+ outputs = get_size_dict(inputs)
153
+ self.assertEqual(outputs, {"shortest_edge": 224})
154
+
155
+ inputs = {"longest_edge": 224, "shortest_edge": 224}
156
+ outputs = get_size_dict(inputs)
157
+ self.assertEqual(outputs, {"longest_edge": 224, "shortest_edge": 224})
158
+
159
+ # Test a single int value which represents (size, size)
160
+ outputs = get_size_dict(224)
161
+ self.assertEqual(outputs, {"height": 224, "width": 224})
162
+
163
+ # Test a single int value which represents the shortest edge
164
+ outputs = get_size_dict(224, default_to_square=False)
165
+ self.assertEqual(outputs, {"shortest_edge": 224})
166
+
167
+ # Test a tuple of ints which represents (height, width)
168
+ outputs = get_size_dict((150, 200))
169
+ self.assertEqual(outputs, {"height": 150, "width": 200})
170
+
171
+ # Test a tuple of ints which represents (width, height)
172
+ outputs = get_size_dict((150, 200), height_width_order=False)
173
+ self.assertEqual(outputs, {"height": 200, "width": 150})
174
+
175
+ # Test an int representing the shortest edge and max_size which represents the longest edge
176
+ outputs = get_size_dict(224, max_size=256, default_to_square=False)
177
+ self.assertEqual(outputs, {"shortest_edge": 224, "longest_edge": 256})
178
+
179
+ # Test int with default_to_square=True and max_size fails
180
+ with self.assertRaises(ValueError):
181
+ get_size_dict(224, max_size=256, default_to_square=True)
docs/transformers/tests/utils/test_image_utils.py ADDED
@@ -0,0 +1,1061 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 HuggingFace Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import codecs
16
+ import os
17
+ import tempfile
18
+ import unittest
19
+ from io import BytesIO
20
+ from typing import Optional
21
+
22
+ import numpy as np
23
+ import pytest
24
+ import requests
25
+ from huggingface_hub.file_download import hf_hub_url, http_get
26
+ from requests import ConnectTimeout, ReadTimeout
27
+
28
+ from tests.pipelines.test_pipelines_document_question_answering import INVOICE_URL
29
+ from transformers import is_torch_available, is_vision_available
30
+ from transformers.image_utils import (
31
+ ChannelDimension,
32
+ get_channel_dimension_axis,
33
+ make_batched_videos,
34
+ make_flat_list_of_images,
35
+ make_list_of_images,
36
+ make_nested_list_of_images,
37
+ )
38
+ from transformers.testing_utils import is_flaky, require_torch, require_vision
39
+
40
+
41
+ if is_torch_available():
42
+ import torch
43
+
44
+ if is_vision_available():
45
+ import PIL.Image
46
+
47
+ from transformers import ImageFeatureExtractionMixin
48
+ from transformers.image_utils import get_image_size, infer_channel_dimension_format, load_image
49
+
50
+
51
+ def get_image_from_hub_dataset(dataset_id: str, filename: str, revision: Optional[str] = None) -> "PIL.Image.Image":
52
+ url = hf_hub_url(dataset_id, filename, repo_type="dataset", revision=revision)
53
+ return PIL.Image.open(BytesIO(requests.get(url).content))
54
+
55
+
56
+ def get_random_image(height, width):
57
+ random_array = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
58
+ return PIL.Image.fromarray(random_array)
59
+
60
+
61
+ @require_vision
62
+ class ImageFeatureExtractionTester(unittest.TestCase):
63
+ def test_conversion_image_to_array(self):
64
+ feature_extractor = ImageFeatureExtractionMixin()
65
+ image = get_random_image(16, 32)
66
+
67
+ # Conversion with defaults (rescale + channel first)
68
+ array1 = feature_extractor.to_numpy_array(image)
69
+ self.assertTrue(array1.dtype, np.float32)
70
+ self.assertEqual(array1.shape, (3, 16, 32))
71
+
72
+ # Conversion with rescale and not channel first
73
+ array2 = feature_extractor.to_numpy_array(image, channel_first=False)
74
+ self.assertTrue(array2.dtype, np.float32)
75
+ self.assertEqual(array2.shape, (16, 32, 3))
76
+ self.assertTrue(np.array_equal(array1, array2.transpose(2, 0, 1)))
77
+
78
+ # Conversion with no rescale and channel first
79
+ array3 = feature_extractor.to_numpy_array(image, rescale=False)
80
+ self.assertTrue(array3.dtype, np.uint8)
81
+ self.assertEqual(array3.shape, (3, 16, 32))
82
+ self.assertTrue(np.array_equal(array1, array3.astype(np.float32) * (1 / 255.0)))
83
+
84
+ # Conversion with no rescale and not channel first
85
+ array4 = feature_extractor.to_numpy_array(image, rescale=False, channel_first=False)
86
+ self.assertTrue(array4.dtype, np.uint8)
87
+ self.assertEqual(array4.shape, (16, 32, 3))
88
+ self.assertTrue(np.array_equal(array2, array4.astype(np.float32) * (1 / 255.0)))
89
+
90
+ def test_conversion_array_to_array(self):
91
+ feature_extractor = ImageFeatureExtractionMixin()
92
+ array = np.random.randint(0, 256, (16, 32, 3), dtype=np.uint8)
93
+
94
+ # By default, rescale (for an array of ints) and channel permute
95
+ array1 = feature_extractor.to_numpy_array(array)
96
+ self.assertTrue(array1.dtype, np.float32)
97
+ self.assertEqual(array1.shape, (3, 16, 32))
98
+ self.assertTrue(np.array_equal(array1, array.transpose(2, 0, 1).astype(np.float32) * (1 / 255.0)))
99
+
100
+ # Same with no permute
101
+ array2 = feature_extractor.to_numpy_array(array, channel_first=False)
102
+ self.assertTrue(array2.dtype, np.float32)
103
+ self.assertEqual(array2.shape, (16, 32, 3))
104
+ self.assertTrue(np.array_equal(array2, array.astype(np.float32) * (1 / 255.0)))
105
+
106
+ # Force rescale to False
107
+ array3 = feature_extractor.to_numpy_array(array, rescale=False)
108
+ self.assertTrue(array3.dtype, np.uint8)
109
+ self.assertEqual(array3.shape, (3, 16, 32))
110
+ self.assertTrue(np.array_equal(array3, array.transpose(2, 0, 1)))
111
+
112
+ # Force rescale to False and no channel permute
113
+ array4 = feature_extractor.to_numpy_array(array, rescale=False, channel_first=False)
114
+ self.assertTrue(array4.dtype, np.uint8)
115
+ self.assertEqual(array4.shape, (16, 32, 3))
116
+ self.assertTrue(np.array_equal(array4, array))
117
+
118
+ # Now test the default rescale for a float array (defaults to False)
119
+ array5 = feature_extractor.to_numpy_array(array2)
120
+ self.assertTrue(array5.dtype, np.float32)
121
+ self.assertEqual(array5.shape, (3, 16, 32))
122
+ self.assertTrue(np.array_equal(array5, array1))
123
+
124
+ def test_make_list_of_images_pil(self):
125
+ # Test a single image is converted to a list of 1 image
126
+ pil_image = get_random_image(16, 32)
127
+ images_list = make_list_of_images(pil_image)
128
+ self.assertIsInstance(images_list, list)
129
+ self.assertEqual(len(images_list), 1)
130
+ self.assertIsInstance(images_list[0], PIL.Image.Image)
131
+
132
+ # Test a list of images is not modified
133
+ images = [get_random_image(16, 32) for _ in range(4)]
134
+ images_list = make_list_of_images(images)
135
+ self.assertIsInstance(images_list, list)
136
+ self.assertEqual(len(images_list), 4)
137
+ self.assertIsInstance(images_list[0], PIL.Image.Image)
138
+
139
+ def test_make_list_of_images_numpy(self):
140
+ # Test a single image is converted to a list of 1 image
141
+ images = np.random.randint(0, 256, (16, 32, 3))
142
+ images_list = make_list_of_images(images)
143
+ self.assertEqual(len(images_list), 1)
144
+ self.assertTrue(np.array_equal(images_list[0], images))
145
+ self.assertIsInstance(images_list, list)
146
+
147
+ # Test a batch of images is converted to a list of images
148
+ images = np.random.randint(0, 256, (4, 16, 32, 3))
149
+ images_list = make_list_of_images(images)
150
+ self.assertEqual(len(images_list), 4)
151
+ self.assertTrue(np.array_equal(images_list[0], images[0]))
152
+ self.assertIsInstance(images_list, list)
153
+
154
+ # Test a list of images is not modified
155
+ images = [np.random.randint(0, 256, (16, 32, 3)) for _ in range(4)]
156
+ images_list = make_list_of_images(images)
157
+ self.assertEqual(len(images_list), 4)
158
+ self.assertTrue(np.array_equal(images_list[0], images[0]))
159
+ self.assertIsInstance(images_list, list)
160
+
161
+ # Test batched masks with no channel dimension are converted to a list of masks
162
+ masks = np.random.randint(0, 2, (4, 16, 32))
163
+ masks_list = make_list_of_images(masks, expected_ndims=2)
164
+ self.assertEqual(len(masks_list), 4)
165
+ self.assertTrue(np.array_equal(masks_list[0], masks[0]))
166
+ self.assertIsInstance(masks_list, list)
167
+
168
+ @require_torch
169
+ def test_make_list_of_images_torch(self):
170
+ # Test a single image is converted to a list of 1 image
171
+ images = torch.randint(0, 256, (16, 32, 3))
172
+ images_list = make_list_of_images(images)
173
+ self.assertEqual(len(images_list), 1)
174
+ self.assertTrue(np.array_equal(images_list[0], images))
175
+ self.assertIsInstance(images_list, list)
176
+
177
+ # Test a batch of images is converted to a list of images
178
+ images = torch.randint(0, 256, (4, 16, 32, 3))
179
+ images_list = make_list_of_images(images)
180
+ self.assertEqual(len(images_list), 4)
181
+ self.assertTrue(np.array_equal(images_list[0], images[0]))
182
+ self.assertIsInstance(images_list, list)
183
+
184
+ # Test a list of images is left unchanged
185
+ images = [torch.randint(0, 256, (16, 32, 3)) for _ in range(4)]
186
+ images_list = make_list_of_images(images)
187
+ self.assertEqual(len(images_list), 4)
188
+ self.assertTrue(np.array_equal(images_list[0], images[0]))
189
+ self.assertIsInstance(images_list, list)
190
+
191
+ def test_make_flat_list_of_images_pil(self):
192
+ # Test a single image is converted to a list of 1 image
193
+ pil_image = get_random_image(16, 32)
194
+ images_list = make_flat_list_of_images(pil_image)
195
+ self.assertIsInstance(images_list, list)
196
+ self.assertEqual(len(images_list), 1)
197
+ self.assertIsInstance(images_list[0], PIL.Image.Image)
198
+
199
+ # Test a list of images is not modified
200
+ images = [get_random_image(16, 32) for _ in range(4)]
201
+ images_list = make_flat_list_of_images(images)
202
+ self.assertIsInstance(images_list, list)
203
+ self.assertEqual(len(images_list), 4)
204
+ self.assertIsInstance(images_list[0], PIL.Image.Image)
205
+
206
+ # Test a nested list of images is flattened
207
+ images = [[get_random_image(16, 32) for _ in range(2)] for _ in range(2)]
208
+ images_list = make_flat_list_of_images(images)
209
+ self.assertIsInstance(images_list, list)
210
+ self.assertEqual(len(images_list), 4)
211
+ self.assertIsInstance(images_list[0], PIL.Image.Image)
212
+
213
+ def test_make_flat_list_of_images_numpy(self):
214
+ # Test a single image is converted to a list of 1 image
215
+ images = np.random.randint(0, 256, (16, 32, 3))
216
+ images_list = make_flat_list_of_images(images)
217
+ self.assertEqual(len(images_list), 1)
218
+ self.assertTrue(np.array_equal(images_list[0], images))
219
+ self.assertIsInstance(images_list, list)
220
+
221
+ # Test a 4d array of images is changed to a list of images
222
+ images = np.random.randint(0, 256, (4, 16, 32, 3))
223
+ images_list = make_flat_list_of_images(images)
224
+ self.assertEqual(len(images_list), 4)
225
+ self.assertIsInstance(images_list, list)
226
+ self.assertIsInstance(images_list[0], np.ndarray)
227
+ self.assertTrue(np.array_equal(images_list[0], images[0]))
228
+
229
+ # Test a list of images is not modified
230
+ images = [np.random.randint(0, 256, (16, 32, 3)) for _ in range(4)]
231
+ images_list = make_flat_list_of_images(images)
232
+ self.assertEqual(len(images_list), 4)
233
+ self.assertTrue(np.array_equal(images_list[0], images[0]))
234
+ self.assertIsInstance(images_list, list)
235
+
236
+ # Test list of 4d array images is flattened
237
+ images = [np.random.randint(0, 256, (4, 16, 32, 3)) for _ in range(2)]
238
+ images_list = make_flat_list_of_images(images)
239
+ self.assertEqual(len(images_list), 8)
240
+ self.assertTrue(np.array_equal(images_list[0], images[0][0]))
241
+ self.assertIsInstance(images_list, list)
242
+ self.assertIsInstance(images_list[0], np.ndarray)
243
+
244
+ # Test nested list of images is flattened
245
+ images = [[np.random.randint(0, 256, (16, 32, 3)) for _ in range(2)] for _ in range(2)]
246
+ images_list = make_flat_list_of_images(images)
247
+ self.assertEqual(len(images_list), 4)
248
+ self.assertTrue(np.array_equal(images_list[0], images[0][0]))
249
+ self.assertIsInstance(images_list, list)
250
+
251
+ @require_torch
252
+ def test_make_flat_list_of_images_torch(self):
253
+ # Test a single image is converted to a list of 1 image
254
+ images = torch.randint(0, 256, (16, 32, 3))
255
+ images_list = make_flat_list_of_images(images)
256
+ self.assertEqual(len(images_list), 1)
257
+ self.assertTrue(np.array_equal(images_list[0], images))
258
+ self.assertIsInstance(images_list, list)
259
+
260
+ # Test a 4d tensors of images is changed to a list of images
261
+ images = torch.randint(0, 256, (4, 16, 32, 3))
262
+ images_list = make_flat_list_of_images(images)
263
+ self.assertEqual(len(images_list), 4)
264
+ self.assertIsInstance(images_list, list)
265
+ self.assertIsInstance(images_list[0], torch.Tensor)
266
+ self.assertTrue(np.array_equal(images_list[0], images[0]))
267
+
268
+ # Test a list of images is not modified
269
+ images = [torch.randint(0, 256, (16, 32, 3)) for _ in range(4)]
270
+ images_list = make_flat_list_of_images(images)
271
+ self.assertEqual(len(images_list), 4)
272
+ self.assertTrue(np.array_equal(images_list[0], images[0]))
273
+ self.assertIsInstance(images_list, list)
274
+
275
+ # Test list of 4d tensors of imagess is flattened
276
+ images = [torch.randint(0, 256, (4, 16, 32, 3)) for _ in range(2)]
277
+ images_list = make_flat_list_of_images(images)
278
+ self.assertEqual(len(images_list), 8)
279
+ self.assertTrue(np.array_equal(images_list[0], images[0][0]))
280
+ self.assertIsInstance(images_list, list)
281
+ self.assertIsInstance(images_list[0], torch.Tensor)
282
+
283
+ # Test nested list of images is flattened
284
+ images = [[torch.randint(0, 256, (16, 32, 3)) for _ in range(2)] for _ in range(2)]
285
+ images_list = make_flat_list_of_images(images)
286
+ self.assertEqual(len(images_list), 4)
287
+ self.assertTrue(np.array_equal(images_list[0], images[0][0]))
288
+ self.assertIsInstance(images_list, list)
289
+
290
+ def test_make_nested_list_of_images_pil(self):
291
+ # Test a single image is converted to a nested list of 1 image
292
+ pil_image = get_random_image(16, 32)
293
+ images_list = make_nested_list_of_images(pil_image)
294
+ self.assertIsInstance(images_list[0], list)
295
+ self.assertEqual(len(images_list[0]), 1)
296
+ self.assertIsInstance(images_list[0][0], PIL.Image.Image)
297
+
298
+ # Test a list of images is converted to a nested list of images
299
+ images = [get_random_image(16, 32) for _ in range(4)]
300
+ images_list = make_nested_list_of_images(images)
301
+ self.assertIsInstance(images_list[0], list)
302
+ self.assertEqual(len(images_list), 1)
303
+ self.assertEqual(len(images_list[0]), 4)
304
+ self.assertIsInstance(images_list[0][0], PIL.Image.Image)
305
+
306
+ # Test a nested list of images is not modified
307
+ images = [[get_random_image(16, 32) for _ in range(2)] for _ in range(2)]
308
+ images_list = make_nested_list_of_images(images)
309
+ self.assertIsInstance(images_list[0], list)
310
+ self.assertEqual(len(images_list), 2)
311
+ self.assertEqual(len(images_list[0]), 2)
312
+ self.assertIsInstance(images_list[0][0], PIL.Image.Image)
313
+
314
+ def test_make_nested_list_of_images_numpy(self):
315
+ # Test a single image is converted to a nested list of 1 image
316
+ images = np.random.randint(0, 256, (16, 32, 3))
317
+ images_list = make_nested_list_of_images(images)
318
+ self.assertIsInstance(images_list[0], list)
319
+ self.assertEqual(len(images_list), 1)
320
+ self.assertTrue(np.array_equal(images_list[0][0], images))
321
+
322
+ # Test a 4d array of images is converted to a nested list of images
323
+ images = np.random.randint(0, 256, (4, 16, 32, 3))
324
+ images_list = make_nested_list_of_images(images)
325
+ self.assertIsInstance(images_list[0], list)
326
+ self.assertIsInstance(images_list[0][0], np.ndarray)
327
+ self.assertEqual(len(images_list), 1)
328
+ self.assertEqual(len(images_list[0]), 4)
329
+ self.assertTrue(np.array_equal(images_list[0][0], images[0]))
330
+
331
+ # Test a list of images is converted to a nested list of images
332
+ images = [np.random.randint(0, 256, (16, 32, 3)) for _ in range(4)]
333
+ images_list = make_nested_list_of_images(images)
334
+ self.assertIsInstance(images_list[0], list)
335
+ self.assertEqual(len(images_list), 1)
336
+ self.assertEqual(len(images_list[0]), 4)
337
+ self.assertTrue(np.array_equal(images_list[0][0], images[0]))
338
+
339
+ # Test a nested list of images is left unchanged
340
+ images = [[np.random.randint(0, 256, (16, 32, 3)) for _ in range(2)] for _ in range(2)]
341
+ images_list = make_nested_list_of_images(images)
342
+ self.assertIsInstance(images_list[0], list)
343
+ self.assertEqual(len(images_list), 2)
344
+ self.assertEqual(len(images_list[0]), 2)
345
+ self.assertTrue(np.array_equal(images_list[0][0], images[0][0]))
346
+
347
+ # Test a list of 4d array images is converted to a nested list of images
348
+ images = [np.random.randint(0, 256, (4, 16, 32, 3)) for _ in range(2)]
349
+ images_list = make_nested_list_of_images(images)
350
+ self.assertIsInstance(images_list[0], list)
351
+ self.assertIsInstance(images_list[0][0], np.ndarray)
352
+ self.assertEqual(len(images_list), 2)
353
+ self.assertEqual(len(images_list[0]), 4)
354
+ self.assertTrue(np.array_equal(images_list[0][0], images[0][0]))
355
+
356
+ @require_torch
357
+ def test_make_nested_list_of_images_torch(self):
358
+ # Test a single image is converted to a nested list of 1 image
359
+ images = torch.randint(0, 256, (16, 32, 3))
360
+ images_list = make_nested_list_of_images(images)
361
+ self.assertIsInstance(images_list[0], list)
362
+ self.assertEqual(len(images_list[0]), 1)
363
+ self.assertTrue(np.array_equal(images_list[0][0], images))
364
+
365
+ # Test a 4d tensor of images is converted to a nested list of images
366
+ images = torch.randint(0, 256, (4, 16, 32, 3))
367
+ images_list = make_nested_list_of_images(images)
368
+ self.assertIsInstance(images_list[0], list)
369
+ self.assertIsInstance(images_list[0][0], torch.Tensor)
370
+ self.assertEqual(len(images_list), 1)
371
+ self.assertEqual(len(images_list[0]), 4)
372
+ self.assertTrue(np.array_equal(images_list[0][0], images[0]))
373
+
374
+ # Test a list of images is converted to a nested list of images
375
+ images = [torch.randint(0, 256, (16, 32, 3)) for _ in range(4)]
376
+ images_list = make_nested_list_of_images(images)
377
+ self.assertIsInstance(images_list[0], list)
378
+ self.assertEqual(len(images_list), 1)
379
+ self.assertEqual(len(images_list[0]), 4)
380
+ self.assertTrue(np.array_equal(images_list[0][0], images[0]))
381
+
382
+ # Test a nested list of images is left unchanged
383
+ images = [[torch.randint(0, 256, (16, 32, 3)) for _ in range(2)] for _ in range(2)]
384
+ images_list = make_nested_list_of_images(images)
385
+ self.assertIsInstance(images_list[0], list)
386
+ self.assertEqual(len(images_list), 2)
387
+ self.assertEqual(len(images_list[0]), 2)
388
+ self.assertTrue(np.array_equal(images_list[0][0], images[0][0]))
389
+
390
+ # Test a list of 4d tensor images is converted to a nested list of images
391
+ images = [torch.randint(0, 256, (4, 16, 32, 3)) for _ in range(2)]
392
+ images_list = make_nested_list_of_images(images)
393
+ self.assertIsInstance(images_list[0], list)
394
+ self.assertIsInstance(images_list[0][0], torch.Tensor)
395
+ self.assertEqual(len(images_list), 2)
396
+ self.assertEqual(len(images_list[0]), 4)
397
+ self.assertTrue(np.array_equal(images_list[0][0], images[0][0]))
398
+
399
+ def test_make_batched_videos_pil(self):
400
+ # Test a single image is converted to a list of 1 video with 1 frame
401
+ pil_image = get_random_image(16, 32)
402
+ videos_list = make_batched_videos(pil_image)
403
+ self.assertIsInstance(videos_list[0], list)
404
+ self.assertEqual(len(videos_list[0]), 1)
405
+ self.assertIsInstance(videos_list[0][0], PIL.Image.Image)
406
+
407
+ # Test a list of images is converted to a list of 1 video
408
+ images = [get_random_image(16, 32) for _ in range(4)]
409
+ videos_list = make_batched_videos(images)
410
+ self.assertIsInstance(videos_list[0], list)
411
+ self.assertEqual(len(videos_list), 1)
412
+ self.assertEqual(len(videos_list[0]), 4)
413
+ self.assertIsInstance(videos_list[0][0], PIL.Image.Image)
414
+
415
+ # Test a nested list of images is not modified
416
+ images = [[get_random_image(16, 32) for _ in range(2)] for _ in range(2)]
417
+ videos_list = make_nested_list_of_images(images)
418
+ self.assertIsInstance(videos_list[0], list)
419
+ self.assertEqual(len(videos_list), 2)
420
+ self.assertEqual(len(videos_list[0]), 2)
421
+ self.assertIsInstance(videos_list[0][0], PIL.Image.Image)
422
+
423
+ def test_make_batched_videos_numpy(self):
424
+ # Test a single image is converted to a list of 1 video with 1 frame
425
+ images = np.random.randint(0, 256, (16, 32, 3))
426
+ videos_list = make_batched_videos(images)
427
+ self.assertIsInstance(videos_list[0], list)
428
+ self.assertEqual(len(videos_list), 1)
429
+ self.assertTrue(np.array_equal(videos_list[0][0], images))
430
+
431
+ # Test a 4d array of images is converted to a list of 1 video
432
+ images = np.random.randint(0, 256, (4, 16, 32, 3))
433
+ videos_list = make_batched_videos(images)
434
+ self.assertIsInstance(videos_list[0], list)
435
+ self.assertIsInstance(videos_list[0][0], np.ndarray)
436
+ self.assertEqual(len(videos_list), 1)
437
+ self.assertEqual(len(videos_list[0]), 4)
438
+ self.assertTrue(np.array_equal(videos_list[0][0], images[0]))
439
+
440
+ # Test a list of images is converted to a list of videos
441
+ images = [np.random.randint(0, 256, (16, 32, 3)) for _ in range(4)]
442
+ videos_list = make_batched_videos(images)
443
+ self.assertIsInstance(videos_list[0], list)
444
+ self.assertEqual(len(videos_list), 1)
445
+ self.assertEqual(len(videos_list[0]), 4)
446
+ self.assertTrue(np.array_equal(videos_list[0][0], images[0]))
447
+
448
+ # Test a nested list of images is left unchanged
449
+ images = [[np.random.randint(0, 256, (16, 32, 3)) for _ in range(2)] for _ in range(2)]
450
+ videos_list = make_batched_videos(images)
451
+ self.assertIsInstance(videos_list[0], list)
452
+ self.assertEqual(len(videos_list), 2)
453
+ self.assertEqual(len(videos_list[0]), 2)
454
+ self.assertTrue(np.array_equal(videos_list[0][0], images[0][0]))
455
+
456
+ # Test a list of 4d array images is converted to a list of videos
457
+ images = [np.random.randint(0, 256, (4, 16, 32, 3)) for _ in range(2)]
458
+ videos_list = make_batched_videos(images)
459
+ self.assertIsInstance(videos_list[0], list)
460
+ self.assertIsInstance(videos_list[0][0], np.ndarray)
461
+ self.assertEqual(len(videos_list), 2)
462
+ self.assertEqual(len(videos_list[0]), 4)
463
+ self.assertTrue(np.array_equal(videos_list[0][0], images[0][0]))
464
+
465
+ # Test a batch of list of 4d array images is converted to a list of videos
466
+ images = [[np.random.randint(0, 256, (4, 16, 32, 3)) for _ in range(2)] for _ in range(2)]
467
+ videos_list = make_batched_videos(images)
468
+ self.assertIsInstance(videos_list[0], list)
469
+ self.assertIsInstance(videos_list[0][0], np.ndarray)
470
+ self.assertEqual(len(videos_list), 2)
471
+ self.assertEqual(len(videos_list[0]), 8)
472
+ self.assertTrue(np.array_equal(videos_list[0][0], images[0][0][0]))
473
+
474
+ @require_torch
475
+ def test_make_batched_videos_torch(self):
476
+ # Test a single image is converted to a list of 1 video with 1 frame
477
+ images = torch.randint(0, 256, (16, 32, 3))
478
+ videos_list = make_batched_videos(images)
479
+ self.assertIsInstance(videos_list[0], list)
480
+ self.assertEqual(len(videos_list[0]), 1)
481
+ self.assertTrue(np.array_equal(videos_list[0][0], images))
482
+
483
+ # Test a 4d tensor of images is converted to a list of 1 video
484
+ images = torch.randint(0, 256, (4, 16, 32, 3))
485
+ videos_list = make_batched_videos(images)
486
+ self.assertIsInstance(videos_list[0], list)
487
+ self.assertIsInstance(videos_list[0][0], torch.Tensor)
488
+ self.assertEqual(len(videos_list), 1)
489
+ self.assertEqual(len(videos_list[0]), 4)
490
+ self.assertTrue(np.array_equal(videos_list[0][0], images[0]))
491
+
492
+ # Test a list of images is converted to a list of videos
493
+ images = [torch.randint(0, 256, (16, 32, 3)) for _ in range(4)]
494
+ videos_list = make_batched_videos(images)
495
+ self.assertIsInstance(videos_list[0], list)
496
+ self.assertEqual(len(videos_list), 1)
497
+ self.assertEqual(len(videos_list[0]), 4)
498
+ self.assertTrue(np.array_equal(videos_list[0][0], images[0]))
499
+
500
+ # Test a nested list of images is left unchanged
501
+ images = [[torch.randint(0, 256, (16, 32, 3)) for _ in range(2)] for _ in range(2)]
502
+ videos_list = make_batched_videos(images)
503
+ self.assertIsInstance(videos_list[0], list)
504
+ self.assertEqual(len(videos_list), 2)
505
+ self.assertEqual(len(videos_list[0]), 2)
506
+ self.assertTrue(np.array_equal(videos_list[0][0], images[0][0]))
507
+
508
+ # Test a list of 4d tensor images is converted to a list of videos
509
+ images = [torch.randint(0, 256, (4, 16, 32, 3)) for _ in range(2)]
510
+ videos_list = make_batched_videos(images)
511
+ self.assertIsInstance(videos_list[0], list)
512
+ self.assertIsInstance(videos_list[0][0], torch.Tensor)
513
+ self.assertEqual(len(videos_list), 2)
514
+ self.assertEqual(len(videos_list[0]), 4)
515
+ self.assertTrue(np.array_equal(videos_list[0][0], images[0][0]))
516
+
517
+ # Test a batch of list of 4d tensor images is converted to a list of videos
518
+ images = [[torch.randint(0, 256, (4, 16, 32, 3)) for _ in range(2)] for _ in range(2)]
519
+ videos_list = make_batched_videos(images)
520
+ self.assertIsInstance(videos_list[0], list)
521
+ self.assertIsInstance(videos_list[0][0], torch.Tensor)
522
+ self.assertEqual(len(videos_list), 2)
523
+ self.assertEqual(len(videos_list[0]), 8)
524
+ self.assertTrue(np.array_equal(videos_list[0][0], images[0][0][0]))
525
+
526
+ @require_torch
527
+ def test_conversion_torch_to_array(self):
528
+ feature_extractor = ImageFeatureExtractionMixin()
529
+ tensor = torch.randint(0, 256, (16, 32, 3))
530
+ array = tensor.numpy()
531
+
532
+ # By default, rescale (for a tensor of ints) and channel permute
533
+ array1 = feature_extractor.to_numpy_array(array)
534
+ self.assertTrue(array1.dtype, np.float32)
535
+ self.assertEqual(array1.shape, (3, 16, 32))
536
+ self.assertTrue(np.array_equal(array1, array.transpose(2, 0, 1).astype(np.float32) * (1 / 255.0)))
537
+
538
+ # Same with no permute
539
+ array2 = feature_extractor.to_numpy_array(array, channel_first=False)
540
+ self.assertTrue(array2.dtype, np.float32)
541
+ self.assertEqual(array2.shape, (16, 32, 3))
542
+ self.assertTrue(np.array_equal(array2, array.astype(np.float32) * (1 / 255.0)))
543
+
544
+ # Force rescale to False
545
+ array3 = feature_extractor.to_numpy_array(array, rescale=False)
546
+ self.assertTrue(array3.dtype, np.uint8)
547
+ self.assertEqual(array3.shape, (3, 16, 32))
548
+ self.assertTrue(np.array_equal(array3, array.transpose(2, 0, 1)))
549
+
550
+ # Force rescale to False and no channel permute
551
+ array4 = feature_extractor.to_numpy_array(array, rescale=False, channel_first=False)
552
+ self.assertTrue(array4.dtype, np.uint8)
553
+ self.assertEqual(array4.shape, (16, 32, 3))
554
+ self.assertTrue(np.array_equal(array4, array))
555
+
556
+ # Now test the default rescale for a float tensor (defaults to False)
557
+ array5 = feature_extractor.to_numpy_array(array2)
558
+ self.assertTrue(array5.dtype, np.float32)
559
+ self.assertEqual(array5.shape, (3, 16, 32))
560
+ self.assertTrue(np.array_equal(array5, array1))
561
+
562
+ def test_conversion_image_to_image(self):
563
+ feature_extractor = ImageFeatureExtractionMixin()
564
+ image = get_random_image(16, 32)
565
+
566
+ # On an image, `to_pil_image1` is a noop.
567
+ image1 = feature_extractor.to_pil_image(image)
568
+ self.assertTrue(isinstance(image, PIL.Image.Image))
569
+ self.assertTrue(np.array_equal(np.array(image), np.array(image1)))
570
+
571
+ def test_conversion_array_to_image(self):
572
+ feature_extractor = ImageFeatureExtractionMixin()
573
+ array = np.random.randint(0, 256, (16, 32, 3), dtype=np.uint8)
574
+
575
+ # By default, no rescale (for an array of ints)
576
+ image1 = feature_extractor.to_pil_image(array)
577
+ self.assertTrue(isinstance(image1, PIL.Image.Image))
578
+ self.assertTrue(np.array_equal(np.array(image1), array))
579
+
580
+ # If the array is channel-first, proper reordering of the channels is done.
581
+ image2 = feature_extractor.to_pil_image(array.transpose(2, 0, 1))
582
+ self.assertTrue(isinstance(image2, PIL.Image.Image))
583
+ self.assertTrue(np.array_equal(np.array(image2), array))
584
+
585
+ # If the array has floating type, it's rescaled by default.
586
+ image3 = feature_extractor.to_pil_image(array.astype(np.float32) * (1 / 255.0))
587
+ self.assertTrue(isinstance(image3, PIL.Image.Image))
588
+ self.assertTrue(np.array_equal(np.array(image3), array))
589
+
590
+ # You can override the default to rescale.
591
+ image4 = feature_extractor.to_pil_image(array.astype(np.float32), rescale=False)
592
+ self.assertTrue(isinstance(image4, PIL.Image.Image))
593
+ self.assertTrue(np.array_equal(np.array(image4), array))
594
+
595
+ # And with floats + channel first.
596
+ image5 = feature_extractor.to_pil_image(array.transpose(2, 0, 1).astype(np.float32) * (1 / 255.0))
597
+ self.assertTrue(isinstance(image5, PIL.Image.Image))
598
+ self.assertTrue(np.array_equal(np.array(image5), array))
599
+
600
+ @require_torch
601
+ def test_conversion_tensor_to_image(self):
602
+ feature_extractor = ImageFeatureExtractionMixin()
603
+ tensor = torch.randint(0, 256, (16, 32, 3))
604
+ array = tensor.numpy()
605
+
606
+ # By default, no rescale (for a tensor of ints)
607
+ image1 = feature_extractor.to_pil_image(tensor)
608
+ self.assertTrue(isinstance(image1, PIL.Image.Image))
609
+ self.assertTrue(np.array_equal(np.array(image1), array))
610
+
611
+ # If the tensor is channel-first, proper reordering of the channels is done.
612
+ image2 = feature_extractor.to_pil_image(tensor.permute(2, 0, 1))
613
+ self.assertTrue(isinstance(image2, PIL.Image.Image))
614
+ self.assertTrue(np.array_equal(np.array(image2), array))
615
+
616
+ # If the tensor has floating type, it's rescaled by default.
617
+ image3 = feature_extractor.to_pil_image(tensor.float() / 255.0)
618
+ self.assertTrue(isinstance(image3, PIL.Image.Image))
619
+ self.assertTrue(np.array_equal(np.array(image3), array))
620
+
621
+ # You can override the default to rescale.
622
+ image4 = feature_extractor.to_pil_image(tensor.float(), rescale=False)
623
+ self.assertTrue(isinstance(image4, PIL.Image.Image))
624
+ self.assertTrue(np.array_equal(np.array(image4), array))
625
+
626
+ # And with floats + channel first.
627
+ image5 = feature_extractor.to_pil_image(tensor.permute(2, 0, 1).float() * (1 / 255.0))
628
+ self.assertTrue(isinstance(image5, PIL.Image.Image))
629
+ self.assertTrue(np.array_equal(np.array(image5), array))
630
+
631
+ def test_resize_image_and_array(self):
632
+ feature_extractor = ImageFeatureExtractionMixin()
633
+ image = get_random_image(16, 32)
634
+ array = np.array(image)
635
+
636
+ # Size can be an int or a tuple of ints.
637
+ resized_image = feature_extractor.resize(image, 8)
638
+ self.assertTrue(isinstance(resized_image, PIL.Image.Image))
639
+ self.assertEqual(resized_image.size, (8, 8))
640
+
641
+ resized_image1 = feature_extractor.resize(image, (8, 16))
642
+ self.assertTrue(isinstance(resized_image1, PIL.Image.Image))
643
+ self.assertEqual(resized_image1.size, (8, 16))
644
+
645
+ # Passing an array converts it to a PIL Image.
646
+ resized_image2 = feature_extractor.resize(array, 8)
647
+ self.assertTrue(isinstance(resized_image2, PIL.Image.Image))
648
+ self.assertEqual(resized_image2.size, (8, 8))
649
+ self.assertTrue(np.array_equal(np.array(resized_image), np.array(resized_image2)))
650
+
651
+ resized_image3 = feature_extractor.resize(image, (8, 16))
652
+ self.assertTrue(isinstance(resized_image3, PIL.Image.Image))
653
+ self.assertEqual(resized_image3.size, (8, 16))
654
+ self.assertTrue(np.array_equal(np.array(resized_image1), np.array(resized_image3)))
655
+
656
+ def test_resize_image_and_array_non_default_to_square(self):
657
+ feature_extractor = ImageFeatureExtractionMixin()
658
+
659
+ heights_widths = [
660
+ # height, width
661
+ # square image
662
+ (28, 28),
663
+ (27, 27),
664
+ # rectangular image: h < w
665
+ (28, 34),
666
+ (29, 35),
667
+ # rectangular image: h > w
668
+ (34, 28),
669
+ (35, 29),
670
+ ]
671
+
672
+ # single integer or single integer in tuple/list
673
+ sizes = [22, 27, 28, 36, [22], (27,)]
674
+
675
+ for (height, width), size in zip(heights_widths, sizes):
676
+ for max_size in (None, 37, 1000):
677
+ image = get_random_image(height, width)
678
+ array = np.array(image)
679
+
680
+ size = size[0] if isinstance(size, (list, tuple)) else size
681
+ # Size can be an int or a tuple of ints.
682
+ # If size is an int, smaller edge of the image will be matched to this number.
683
+ # i.e, if height > width, then image will be rescaled to (size * height / width, size).
684
+ if height < width:
685
+ exp_w, exp_h = (int(size * width / height), size)
686
+ if max_size is not None and max_size < exp_w:
687
+ exp_w, exp_h = max_size, int(max_size * exp_h / exp_w)
688
+ elif width < height:
689
+ exp_w, exp_h = (size, int(size * height / width))
690
+ if max_size is not None and max_size < exp_h:
691
+ exp_w, exp_h = int(max_size * exp_w / exp_h), max_size
692
+ else:
693
+ exp_w, exp_h = (size, size)
694
+ if max_size is not None and max_size < size:
695
+ exp_w, exp_h = max_size, max_size
696
+
697
+ resized_image = feature_extractor.resize(image, size=size, default_to_square=False, max_size=max_size)
698
+ self.assertTrue(isinstance(resized_image, PIL.Image.Image))
699
+ self.assertEqual(resized_image.size, (exp_w, exp_h))
700
+
701
+ # Passing an array converts it to a PIL Image.
702
+ resized_image2 = feature_extractor.resize(array, size=size, default_to_square=False, max_size=max_size)
703
+ self.assertTrue(isinstance(resized_image2, PIL.Image.Image))
704
+ self.assertEqual(resized_image2.size, (exp_w, exp_h))
705
+ self.assertTrue(np.array_equal(np.array(resized_image), np.array(resized_image2)))
706
+
707
+ @require_torch
708
+ def test_resize_tensor(self):
709
+ feature_extractor = ImageFeatureExtractionMixin()
710
+ tensor = torch.randint(0, 256, (16, 32, 3))
711
+ array = tensor.numpy()
712
+
713
+ # Size can be an int or a tuple of ints.
714
+ resized_image = feature_extractor.resize(tensor, 8)
715
+ self.assertTrue(isinstance(resized_image, PIL.Image.Image))
716
+ self.assertEqual(resized_image.size, (8, 8))
717
+
718
+ resized_image1 = feature_extractor.resize(tensor, (8, 16))
719
+ self.assertTrue(isinstance(resized_image1, PIL.Image.Image))
720
+ self.assertEqual(resized_image1.size, (8, 16))
721
+
722
+ # Check we get the same results as with NumPy arrays.
723
+ resized_image2 = feature_extractor.resize(array, 8)
724
+ self.assertTrue(np.array_equal(np.array(resized_image), np.array(resized_image2)))
725
+
726
+ resized_image3 = feature_extractor.resize(array, (8, 16))
727
+ self.assertTrue(np.array_equal(np.array(resized_image1), np.array(resized_image3)))
728
+
729
+ def test_normalize_image(self):
730
+ feature_extractor = ImageFeatureExtractionMixin()
731
+ image = get_random_image(16, 32)
732
+ array = np.array(image)
733
+ mean = [0.1, 0.5, 0.9]
734
+ std = [0.2, 0.4, 0.6]
735
+
736
+ # PIL Image are converted to NumPy arrays for the normalization
737
+ normalized_image = feature_extractor.normalize(image, mean, std)
738
+ self.assertTrue(isinstance(normalized_image, np.ndarray))
739
+ self.assertEqual(normalized_image.shape, (3, 16, 32))
740
+
741
+ # During the conversion rescale and channel first will be applied.
742
+ expected = array.transpose(2, 0, 1).astype(np.float32) * (1 / 255.0)
743
+ np_mean = np.array(mean).astype(np.float32)[:, None, None]
744
+ np_std = np.array(std).astype(np.float32)[:, None, None]
745
+ expected = (expected - np_mean) / np_std
746
+ self.assertTrue(np.array_equal(normalized_image, expected))
747
+
748
+ def test_normalize_array(self):
749
+ feature_extractor = ImageFeatureExtractionMixin()
750
+ array = np.random.random((16, 32, 3))
751
+ mean = [0.1, 0.5, 0.9]
752
+ std = [0.2, 0.4, 0.6]
753
+
754
+ # mean and std can be passed as lists or NumPy arrays.
755
+ expected = (array - np.array(mean)) / np.array(std)
756
+ normalized_array = feature_extractor.normalize(array, mean, std)
757
+ self.assertTrue(np.array_equal(normalized_array, expected))
758
+
759
+ normalized_array = feature_extractor.normalize(array, np.array(mean), np.array(std))
760
+ self.assertTrue(np.array_equal(normalized_array, expected))
761
+
762
+ # Normalize will detect automatically if channel first or channel last is used.
763
+ array = np.random.random((3, 16, 32))
764
+ expected = (array - np.array(mean)[:, None, None]) / np.array(std)[:, None, None]
765
+ normalized_array = feature_extractor.normalize(array, mean, std)
766
+ self.assertTrue(np.array_equal(normalized_array, expected))
767
+
768
+ normalized_array = feature_extractor.normalize(array, np.array(mean), np.array(std))
769
+ self.assertTrue(np.array_equal(normalized_array, expected))
770
+
771
+ @require_torch
772
+ def test_normalize_tensor(self):
773
+ feature_extractor = ImageFeatureExtractionMixin()
774
+ tensor = torch.rand(16, 32, 3)
775
+ mean = [0.1, 0.5, 0.9]
776
+ std = [0.2, 0.4, 0.6]
777
+
778
+ # mean and std can be passed as lists or tensors.
779
+ expected = (tensor - torch.tensor(mean)) / torch.tensor(std)
780
+ normalized_tensor = feature_extractor.normalize(tensor, mean, std)
781
+ self.assertTrue(torch.equal(normalized_tensor, expected))
782
+
783
+ normalized_tensor = feature_extractor.normalize(tensor, torch.tensor(mean), torch.tensor(std))
784
+ self.assertTrue(torch.equal(normalized_tensor, expected))
785
+
786
+ # Normalize will detect automatically if channel first or channel last is used.
787
+ tensor = torch.rand(3, 16, 32)
788
+ expected = (tensor - torch.tensor(mean)[:, None, None]) / torch.tensor(std)[:, None, None]
789
+ normalized_tensor = feature_extractor.normalize(tensor, mean, std)
790
+ self.assertTrue(torch.equal(normalized_tensor, expected))
791
+
792
+ normalized_tensor = feature_extractor.normalize(tensor, torch.tensor(mean), torch.tensor(std))
793
+ self.assertTrue(torch.equal(normalized_tensor, expected))
794
+
795
+ def test_center_crop_image(self):
796
+ feature_extractor = ImageFeatureExtractionMixin()
797
+ image = get_random_image(16, 32)
798
+
799
+ # Test various crop sizes: bigger on all dimensions, on one of the dimensions only and on both dimensions.
800
+ crop_sizes = [8, (8, 64), 20, (32, 64)]
801
+ for size in crop_sizes:
802
+ cropped_image = feature_extractor.center_crop(image, size)
803
+ self.assertTrue(isinstance(cropped_image, PIL.Image.Image))
804
+
805
+ # PIL Image.size is transposed compared to NumPy or PyTorch (width first instead of height first).
806
+ expected_size = (size, size) if isinstance(size, int) else (size[1], size[0])
807
+ self.assertEqual(cropped_image.size, expected_size)
808
+
809
+ def test_center_crop_array(self):
810
+ feature_extractor = ImageFeatureExtractionMixin()
811
+ image = get_random_image(16, 32)
812
+ array = feature_extractor.to_numpy_array(image)
813
+
814
+ # Test various crop sizes: bigger on all dimensions, on one of the dimensions only and on both dimensions.
815
+ crop_sizes = [8, (8, 64), 20, (32, 64)]
816
+ for size in crop_sizes:
817
+ cropped_array = feature_extractor.center_crop(array, size)
818
+ self.assertTrue(isinstance(cropped_array, np.ndarray))
819
+
820
+ expected_size = (size, size) if isinstance(size, int) else size
821
+ self.assertEqual(cropped_array.shape[-2:], expected_size)
822
+
823
+ # Check result is consistent with PIL.Image.crop
824
+ cropped_image = feature_extractor.center_crop(image, size)
825
+ self.assertTrue(np.array_equal(cropped_array, feature_extractor.to_numpy_array(cropped_image)))
826
+
827
+ @require_torch
828
+ def test_center_crop_tensor(self):
829
+ feature_extractor = ImageFeatureExtractionMixin()
830
+ image = get_random_image(16, 32)
831
+ array = feature_extractor.to_numpy_array(image)
832
+ tensor = torch.tensor(array)
833
+
834
+ # Test various crop sizes: bigger on all dimensions, on one of the dimensions only and on both dimensions.
835
+ crop_sizes = [8, (8, 64), 20, (32, 64)]
836
+ for size in crop_sizes:
837
+ cropped_tensor = feature_extractor.center_crop(tensor, size)
838
+ self.assertTrue(isinstance(cropped_tensor, torch.Tensor))
839
+
840
+ expected_size = (size, size) if isinstance(size, int) else size
841
+ self.assertEqual(cropped_tensor.shape[-2:], expected_size)
842
+
843
+ # Check result is consistent with PIL.Image.crop
844
+ cropped_image = feature_extractor.center_crop(image, size)
845
+ self.assertTrue(torch.equal(cropped_tensor, torch.tensor(feature_extractor.to_numpy_array(cropped_image))))
846
+
847
+
848
+ @require_vision
849
+ class LoadImageTester(unittest.TestCase):
850
+ def test_load_img_url(self):
851
+ img = load_image(INVOICE_URL)
852
+ img_arr = np.array(img)
853
+
854
+ self.assertEqual(img_arr.shape, (1061, 750, 3))
855
+
856
+ @is_flaky()
857
+ def test_load_img_url_timeout(self):
858
+ with self.assertRaises((ReadTimeout, ConnectTimeout)):
859
+ load_image(INVOICE_URL, timeout=0.001)
860
+
861
+ def test_load_img_local(self):
862
+ img = load_image("./tests/fixtures/tests_samples/COCO/000000039769.png")
863
+ img_arr = np.array(img)
864
+
865
+ self.assertEqual(
866
+ img_arr.shape,
867
+ (480, 640, 3),
868
+ )
869
+
870
+ def test_load_img_base64_prefix(self):
871
+ try:
872
+ tmp_file = tempfile.NamedTemporaryFile(delete=False).name
873
+ with open(tmp_file, "wb") as f:
874
+ http_get(
875
+ "https://huggingface.co/datasets/hf-internal-testing/dummy-base64-images/raw/main/image_0.txt", f
876
+ )
877
+
878
+ with open(tmp_file, encoding="utf-8") as b64:
879
+ img = load_image(b64.read())
880
+ img_arr = np.array(img)
881
+
882
+ finally:
883
+ os.remove(tmp_file)
884
+
885
+ self.assertEqual(img_arr.shape, (64, 32, 3))
886
+
887
+ def test_load_img_base64(self):
888
+ try:
889
+ tmp_file = tempfile.NamedTemporaryFile(delete=False).name
890
+ with open(tmp_file, "wb") as f:
891
+ http_get(
892
+ "https://huggingface.co/datasets/hf-internal-testing/dummy-base64-images/raw/main/image_1.txt", f
893
+ )
894
+
895
+ with open(tmp_file, encoding="utf-8") as b64:
896
+ img = load_image(b64.read())
897
+ img_arr = np.array(img)
898
+
899
+ finally:
900
+ os.remove(tmp_file)
901
+
902
+ self.assertEqual(img_arr.shape, (64, 32, 3))
903
+
904
+ def test_load_img_base64_encoded_bytes(self):
905
+ try:
906
+ tmp_file = tempfile.NamedTemporaryFile(delete=False).name
907
+ with open(tmp_file, "wb") as f:
908
+ http_get(
909
+ "https://huggingface.co/datasets/hf-internal-testing/dummy-base64-images/raw/main/image_2.txt", f
910
+ )
911
+
912
+ with codecs.open(tmp_file, encoding="unicode_escape") as b64:
913
+ img = load_image(b64.read())
914
+ img_arr = np.array(img)
915
+
916
+ finally:
917
+ os.remove(tmp_file)
918
+
919
+ self.assertEqual(img_arr.shape, (256, 256, 3))
920
+
921
+ def test_load_img_rgba(self):
922
+ # we use revision="refs/pr/1" until the PR is merged
923
+ # https://hf.co/datasets/hf-internal-testing/fixtures_image_utils/discussions/1
924
+ img = get_image_from_hub_dataset(
925
+ "hf-internal-testing/fixtures_image_utils", "0-test-lena.png", revision="refs/pr/1"
926
+ )
927
+
928
+ img = load_image(img) # img with mode RGBA
929
+ img_arr = np.array(img)
930
+
931
+ self.assertEqual(
932
+ img_arr.shape,
933
+ (512, 512, 3),
934
+ )
935
+
936
+ def test_load_img_la(self):
937
+ # we use revision="refs/pr/1" until the PR is merged
938
+ # https://hf.co/datasets/hf-internal-testing/fixtures_image_utils/discussions/1
939
+ img = get_image_from_hub_dataset(
940
+ "hf-internal-testing/fixtures_image_utils", "1-test-parrots.png", revision="refs/pr/1"
941
+ )
942
+
943
+ img = load_image(img) # img with mode LA
944
+ img_arr = np.array(img)
945
+
946
+ self.assertEqual(
947
+ img_arr.shape,
948
+ (512, 768, 3),
949
+ )
950
+
951
+ def test_load_img_l(self):
952
+ # we use revision="refs/pr/1" until the PR is merged
953
+ # https://hf.co/datasets/hf-internal-testing/fixtures_image_utils/discussions/1
954
+ img = get_image_from_hub_dataset(
955
+ "hf-internal-testing/fixtures_image_utils", "2-test-tree.png", revision="refs/pr/1"
956
+ )
957
+
958
+ img = load_image(img) # img with mode L
959
+ img_arr = np.array(img)
960
+
961
+ self.assertEqual(
962
+ img_arr.shape,
963
+ (381, 225, 3),
964
+ )
965
+
966
+ def test_load_img_exif_transpose(self):
967
+ # we use revision="refs/pr/1" until the PR is merged
968
+ # https://hf.co/datasets/hf-internal-testing/fixtures_image_utils/discussions/1
969
+
970
+ img_without_exif_transpose = get_image_from_hub_dataset(
971
+ "hf-internal-testing/fixtures_image_utils", "3-test-cat-rotated.jpg", revision="refs/pr/1"
972
+ )
973
+ img_arr_without_exif_transpose = np.array(img_without_exif_transpose)
974
+
975
+ self.assertEqual(
976
+ img_arr_without_exif_transpose.shape,
977
+ (333, 500, 3),
978
+ )
979
+
980
+ img_with_exif_transpose = load_image(img_without_exif_transpose)
981
+ img_arr_with_exif_transpose = np.array(img_with_exif_transpose)
982
+
983
+ self.assertEqual(
984
+ img_arr_with_exif_transpose.shape,
985
+ (500, 333, 3),
986
+ )
987
+
988
+
989
+ class UtilFunctionTester(unittest.TestCase):
990
+ def test_get_image_size(self):
991
+ # Test we can infer the size and channel dimension of an image.
992
+ image = np.random.randint(0, 256, (32, 64, 3))
993
+ self.assertEqual(get_image_size(image), (32, 64))
994
+
995
+ image = np.random.randint(0, 256, (3, 32, 64))
996
+ self.assertEqual(get_image_size(image), (32, 64))
997
+
998
+ # Test the channel dimension can be overridden
999
+ image = np.random.randint(0, 256, (3, 32, 64))
1000
+ self.assertEqual(get_image_size(image, channel_dim=ChannelDimension.LAST), (3, 32))
1001
+
1002
+ def test_infer_channel_dimension(self):
1003
+ # Test we fail with invalid input
1004
+ with pytest.raises(ValueError):
1005
+ infer_channel_dimension_format(np.random.randint(0, 256, (10, 10)))
1006
+
1007
+ with pytest.raises(ValueError):
1008
+ infer_channel_dimension_format(np.random.randint(0, 256, (10, 10, 10, 10, 10)))
1009
+
1010
+ # Test we fail if neither first not last dimension is of size 3 or 1
1011
+ with pytest.raises(ValueError):
1012
+ infer_channel_dimension_format(np.random.randint(0, 256, (10, 1, 50)))
1013
+
1014
+ # But if we explicitly set one of the number of channels to 50 it works
1015
+ inferred_dim = infer_channel_dimension_format(np.random.randint(0, 256, (10, 1, 50)), num_channels=50)
1016
+ self.assertEqual(inferred_dim, ChannelDimension.LAST)
1017
+
1018
+ # Test we correctly identify the channel dimension
1019
+ image = np.random.randint(0, 256, (3, 4, 5))
1020
+ inferred_dim = infer_channel_dimension_format(image)
1021
+ self.assertEqual(inferred_dim, ChannelDimension.FIRST)
1022
+
1023
+ image = np.random.randint(0, 256, (1, 4, 5))
1024
+ inferred_dim = infer_channel_dimension_format(image)
1025
+ self.assertEqual(inferred_dim, ChannelDimension.FIRST)
1026
+
1027
+ image = np.random.randint(0, 256, (4, 5, 3))
1028
+ inferred_dim = infer_channel_dimension_format(image)
1029
+ self.assertEqual(inferred_dim, ChannelDimension.LAST)
1030
+
1031
+ image = np.random.randint(0, 256, (4, 5, 1))
1032
+ inferred_dim = infer_channel_dimension_format(image)
1033
+ self.assertEqual(inferred_dim, ChannelDimension.LAST)
1034
+
1035
+ # We can take a batched array of images and find the dimension
1036
+ image = np.random.randint(0, 256, (1, 3, 4, 5))
1037
+ inferred_dim = infer_channel_dimension_format(image)
1038
+ self.assertEqual(inferred_dim, ChannelDimension.FIRST)
1039
+
1040
+ def test_get_channel_dimension_axis(self):
1041
+ # Test we correctly identify the channel dimension
1042
+ image = np.random.randint(0, 256, (3, 4, 5))
1043
+ inferred_axis = get_channel_dimension_axis(image)
1044
+ self.assertEqual(inferred_axis, 0)
1045
+
1046
+ image = np.random.randint(0, 256, (1, 4, 5))
1047
+ inferred_axis = get_channel_dimension_axis(image)
1048
+ self.assertEqual(inferred_axis, 0)
1049
+
1050
+ image = np.random.randint(0, 256, (4, 5, 3))
1051
+ inferred_axis = get_channel_dimension_axis(image)
1052
+ self.assertEqual(inferred_axis, 2)
1053
+
1054
+ image = np.random.randint(0, 256, (4, 5, 1))
1055
+ inferred_axis = get_channel_dimension_axis(image)
1056
+ self.assertEqual(inferred_axis, 2)
1057
+
1058
+ # We can take a batched array of images and find the dimension
1059
+ image = np.random.randint(0, 256, (1, 3, 4, 5))
1060
+ inferred_axis = get_channel_dimension_axis(image)
1061
+ self.assertEqual(inferred_axis, 1)
docs/transformers/tests/utils/test_import_structure.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import unittest
3
+ from pathlib import Path
4
+
5
+ from transformers.utils.import_utils import define_import_structure, spread_import_structure
6
+
7
+
8
+ import_structures = Path("import_structures")
9
+
10
+
11
+ def fetch__all__(file_content):
12
+ """
13
+ Returns the content of the __all__ variable in the file content.
14
+ Returns None if not defined, otherwise returns a list of strings.
15
+ """
16
+ lines = file_content.split("\n")
17
+ for line_index in range(len(lines)):
18
+ line = lines[line_index]
19
+ if line.startswith("__all__ = "):
20
+ # __all__ is defined on a single line
21
+ if line.endswith("]"):
22
+ return [obj.strip("\"' ") for obj in line.split("=")[1].strip(" []").split(",")]
23
+
24
+ # __all__ is defined on multiple lines
25
+ else:
26
+ _all = []
27
+ for __all__line_index in range(line_index + 1, len(lines)):
28
+ if lines[__all__line_index].strip() == "]":
29
+ return _all
30
+ else:
31
+ _all.append(lines[__all__line_index].strip("\"', "))
32
+
33
+
34
+ class TestImportStructures(unittest.TestCase):
35
+ base_transformers_path = Path(__file__).parent.parent.parent
36
+ models_path = base_transformers_path / "src" / "transformers" / "models"
37
+ models_import_structure = spread_import_structure(define_import_structure(models_path))
38
+
39
+ # TODO: Lysandre
40
+ # See https://app.circleci.com/pipelines/github/huggingface/transformers/104762/workflows/7ba9c6f7-a3b2-44e6-8eaf-749c7b7261f7/jobs/1393260/tests
41
+ @unittest.skip(reason="failing")
42
+ def test_definition(self):
43
+ import_structure = define_import_structure(import_structures)
44
+ import_structure_definition = {
45
+ frozenset(()): {
46
+ "import_structure_raw_register": {"A0", "a0", "A4"},
47
+ "import_structure_register_with_comments": {"B0", "b0"},
48
+ },
49
+ frozenset(("tf", "torch")): {
50
+ "import_structure_raw_register": {"A1", "a1", "A2", "a2", "A3", "a3"},
51
+ "import_structure_register_with_comments": {"B1", "b1", "B2", "b2", "B3", "b3"},
52
+ },
53
+ frozenset(("torch",)): {
54
+ "import_structure_register_with_duplicates": {"C0", "c0", "C1", "c1", "C2", "c2", "C3", "c3"},
55
+ },
56
+ }
57
+
58
+ self.assertDictEqual(import_structure, import_structure_definition)
59
+
60
+ def test_transformers_specific_model_import(self):
61
+ """
62
+ This test ensures that there is equivalence between what is written down in __all__ and what is
63
+ written down with register().
64
+
65
+ It doesn't test the backends attributed to register().
66
+ """
67
+ for architecture in os.listdir(self.models_path):
68
+ if (
69
+ os.path.isfile(self.models_path / architecture)
70
+ or architecture.startswith("_")
71
+ or architecture == "deprecated"
72
+ ):
73
+ continue
74
+
75
+ with self.subTest(f"Testing arch {architecture}"):
76
+ import_structure = define_import_structure(self.models_path / architecture)
77
+ backend_agnostic_import_structure = {}
78
+ for requirement, module_object_mapping in import_structure.items():
79
+ for module, objects in module_object_mapping.items():
80
+ if module not in backend_agnostic_import_structure:
81
+ backend_agnostic_import_structure[module] = []
82
+
83
+ backend_agnostic_import_structure[module].extend(objects)
84
+
85
+ for module, objects in backend_agnostic_import_structure.items():
86
+ with open(self.models_path / architecture / f"{module}.py") as f:
87
+ content = f.read()
88
+ _all = fetch__all__(content)
89
+
90
+ if _all is None:
91
+ raise ValueError(f"{module} doesn't have __all__ defined.")
92
+
93
+ error_message = (
94
+ f"self.models_path / architecture / f'{module}.py doesn't seem to be defined correctly:\n"
95
+ f"Defined in __all__: {sorted(_all)}\nDefined with register: {sorted(objects)}"
96
+ )
97
+ self.assertListEqual(sorted(objects), sorted(_all), msg=error_message)
98
+
99
+ # TODO: Lysandre
100
+ # See https://app.circleci.com/pipelines/github/huggingface/transformers/104762/workflows/7ba9c6f7-a3b2-44e6-8eaf-749c7b7261f7/jobs/1393260/tests
101
+ @unittest.skip(reason="failing")
102
+ def test_export_backend_should_be_defined(self):
103
+ with self.assertRaisesRegex(ValueError, "Backend should be defined in the BACKENDS_MAPPING"):
104
+ pass
docs/transformers/tests/utils/test_import_utils.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+
3
+ from transformers.testing_utils import run_test_using_subprocess
4
+ from transformers.utils.import_utils import clear_import_cache
5
+
6
+
7
+ @run_test_using_subprocess
8
+ def test_clear_import_cache():
9
+ """Test the clear_import_cache function."""
10
+
11
+ # Save initial state
12
+ initial_modules = {name: mod for name, mod in sys.modules.items() if name.startswith("transformers.")}
13
+ assert len(initial_modules) > 0, "No transformers modules loaded before test"
14
+
15
+ # Execute clear_import_cache() function
16
+ clear_import_cache()
17
+
18
+ # Verify modules were removed
19
+ remaining_modules = {name: mod for name, mod in sys.modules.items() if name.startswith("transformers.")}
20
+ assert len(remaining_modules) < len(initial_modules), "No modules were removed"
21
+
22
+ # Import and verify module exists
23
+ from transformers.models.auto import modeling_auto
24
+
25
+ assert "transformers.models.auto.modeling_auto" in sys.modules
26
+ assert modeling_auto.__name__ == "transformers.models.auto.modeling_auto"
docs/transformers/tests/utils/test_logging.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import unittest
17
+
18
+ from huggingface_hub.utils import are_progress_bars_disabled
19
+
20
+ import transformers.models.bart.tokenization_bart
21
+ from transformers import logging
22
+ from transformers.testing_utils import CaptureLogger, mockenv, mockenv_context
23
+ from transformers.utils.logging import disable_progress_bar, enable_progress_bar
24
+
25
+
26
+ class HfArgumentParserTest(unittest.TestCase):
27
+ def test_set_level(self):
28
+ logger = logging.get_logger()
29
+
30
+ # the current default level is logging.WARNING
31
+ level_origin = logging.get_verbosity()
32
+
33
+ logging.set_verbosity_error()
34
+ self.assertEqual(logger.getEffectiveLevel(), logging.get_verbosity())
35
+
36
+ logging.set_verbosity_warning()
37
+ self.assertEqual(logger.getEffectiveLevel(), logging.get_verbosity())
38
+
39
+ logging.set_verbosity_info()
40
+ self.assertEqual(logger.getEffectiveLevel(), logging.get_verbosity())
41
+
42
+ logging.set_verbosity_debug()
43
+ self.assertEqual(logger.getEffectiveLevel(), logging.get_verbosity())
44
+
45
+ # restore to the original level
46
+ logging.set_verbosity(level_origin)
47
+
48
+ def test_integration(self):
49
+ level_origin = logging.get_verbosity()
50
+
51
+ logger = logging.get_logger("transformers.models.bart.tokenization_bart")
52
+ msg = "Testing 1, 2, 3"
53
+
54
+ # should be able to log warnings (if default settings weren't overridden by `pytest --log-level-all`)
55
+ if level_origin <= logging.WARNING:
56
+ with CaptureLogger(logger) as cl:
57
+ logger.warning(msg)
58
+ self.assertEqual(cl.out, msg + "\n")
59
+
60
+ # this is setting the level for all of `transformers.*` loggers
61
+ logging.set_verbosity_error()
62
+
63
+ # should not be able to log warnings
64
+ with CaptureLogger(logger) as cl:
65
+ logger.warning(msg)
66
+ self.assertEqual(cl.out, "")
67
+
68
+ # should be able to log warnings again
69
+ logging.set_verbosity_warning()
70
+ with CaptureLogger(logger) as cl:
71
+ logger.warning(msg)
72
+ self.assertEqual(cl.out, msg + "\n")
73
+
74
+ # restore to the original level
75
+ logging.set_verbosity(level_origin)
76
+
77
+ @mockenv(TRANSFORMERS_VERBOSITY="error")
78
+ def test_env_override(self):
79
+ # reset for the env var to take effect, next time some logger call is made
80
+ transformers.utils.logging._reset_library_root_logger()
81
+ # this action activates the env var
82
+ _ = logging.get_logger("transformers.models.bart.tokenization_bart")
83
+
84
+ env_level_str = os.getenv("TRANSFORMERS_VERBOSITY", None)
85
+ env_level = logging.log_levels[env_level_str]
86
+
87
+ current_level = logging.get_verbosity()
88
+ self.assertEqual(
89
+ env_level,
90
+ current_level,
91
+ f"TRANSFORMERS_VERBOSITY={env_level_str}/{env_level}, but internal verbosity is {current_level}",
92
+ )
93
+
94
+ # restore to the original level
95
+ os.environ["TRANSFORMERS_VERBOSITY"] = ""
96
+ transformers.utils.logging._reset_library_root_logger()
97
+
98
+ @mockenv(TRANSFORMERS_VERBOSITY="super-error")
99
+ def test_env_invalid_override(self):
100
+ # reset for the env var to take effect, next time some logger call is made
101
+ transformers.utils.logging._reset_library_root_logger()
102
+ logger = logging.logging.getLogger()
103
+ with CaptureLogger(logger) as cl:
104
+ # this action activates the env var
105
+ logging.get_logger("transformers.models.bart.tokenization_bart")
106
+ self.assertIn("Unknown option TRANSFORMERS_VERBOSITY=super-error", cl.out)
107
+
108
+ # no need to restore as nothing was changed
109
+
110
+ def test_advisory_warnings(self):
111
+ # testing `logger.warning_advice()`
112
+ transformers.utils.logging._reset_library_root_logger()
113
+
114
+ logger = logging.get_logger("transformers.models.bart.tokenization_bart")
115
+ msg = "Testing 1, 2, 3"
116
+
117
+ with mockenv_context(TRANSFORMERS_NO_ADVISORY_WARNINGS="1"):
118
+ # nothing should be logged as env var disables this method
119
+ with CaptureLogger(logger) as cl:
120
+ logger.warning_advice(msg)
121
+ self.assertEqual(cl.out, "")
122
+
123
+ with mockenv_context(TRANSFORMERS_NO_ADVISORY_WARNINGS=""):
124
+ # should log normally as TRANSFORMERS_NO_ADVISORY_WARNINGS is unset
125
+ with CaptureLogger(logger) as cl:
126
+ logger.warning_advice(msg)
127
+ self.assertEqual(cl.out, msg + "\n")
128
+
129
+
130
+ def test_set_progress_bar_enabled():
131
+ disable_progress_bar()
132
+ assert are_progress_bars_disabled()
133
+
134
+ enable_progress_bar()
135
+ assert not are_progress_bars_disabled()
docs/transformers/tests/utils/test_model_card.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2019 HuggingFace Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import json
17
+ import os
18
+ import tempfile
19
+ import unittest
20
+
21
+ from transformers.modelcard import ModelCard, TrainingSummary
22
+
23
+
24
+ class ModelCardTester(unittest.TestCase):
25
+ def setUp(self):
26
+ self.inputs_dict = {
27
+ "model_details": {
28
+ "Organization": "testing",
29
+ "Model date": "today",
30
+ "Model version": "v2.1, Developed by Test Corp in 2019.",
31
+ "Architecture": "Convolutional Neural Network.",
32
+ },
33
+ "metrics": "BLEU and ROUGE-1",
34
+ "evaluation_data": {
35
+ "Datasets": {"BLEU": "My-great-dataset-v1", "ROUGE-1": "My-short-dataset-v2.1"},
36
+ "Preprocessing": "See details on https://arxiv.org/pdf/1810.03993.pdf",
37
+ },
38
+ "training_data": {
39
+ "Dataset": "English Wikipedia dump dated 2018-12-01",
40
+ "Preprocessing": (
41
+ "Using SentencePiece vocabulary of size 52k tokens. See details on"
42
+ " https://arxiv.org/pdf/1810.03993.pdf"
43
+ ),
44
+ },
45
+ "quantitative_analyses": {"BLEU": 55.1, "ROUGE-1": 76},
46
+ }
47
+
48
+ def test_model_card_common_properties(self):
49
+ modelcard = ModelCard.from_dict(self.inputs_dict)
50
+ self.assertTrue(hasattr(modelcard, "model_details"))
51
+ self.assertTrue(hasattr(modelcard, "intended_use"))
52
+ self.assertTrue(hasattr(modelcard, "factors"))
53
+ self.assertTrue(hasattr(modelcard, "metrics"))
54
+ self.assertTrue(hasattr(modelcard, "evaluation_data"))
55
+ self.assertTrue(hasattr(modelcard, "training_data"))
56
+ self.assertTrue(hasattr(modelcard, "quantitative_analyses"))
57
+ self.assertTrue(hasattr(modelcard, "ethical_considerations"))
58
+ self.assertTrue(hasattr(modelcard, "caveats_and_recommendations"))
59
+
60
+ def test_model_card_to_json_string(self):
61
+ modelcard = ModelCard.from_dict(self.inputs_dict)
62
+ obj = json.loads(modelcard.to_json_string())
63
+ for key, value in self.inputs_dict.items():
64
+ self.assertEqual(obj[key], value)
65
+
66
+ def test_model_card_to_json_file(self):
67
+ model_card_first = ModelCard.from_dict(self.inputs_dict)
68
+
69
+ with tempfile.TemporaryDirectory() as tmpdirname:
70
+ filename = os.path.join(tmpdirname, "modelcard.json")
71
+ model_card_first.to_json_file(filename)
72
+ model_card_second = ModelCard.from_json_file(filename)
73
+
74
+ self.assertEqual(model_card_second.to_dict(), model_card_first.to_dict())
75
+
76
+ def test_model_card_from_and_save_pretrained(self):
77
+ model_card_first = ModelCard.from_dict(self.inputs_dict)
78
+
79
+ with tempfile.TemporaryDirectory() as tmpdirname:
80
+ model_card_first.save_pretrained(tmpdirname)
81
+ model_card_second = ModelCard.from_pretrained(tmpdirname)
82
+
83
+ self.assertEqual(model_card_second.to_dict(), model_card_first.to_dict())
84
+
85
+ def test_model_summary_modelcard_base_metadata(self):
86
+ metadata = TrainingSummary("Model name").create_metadata()
87
+ self.assertTrue("library_name" in metadata)
88
+ self.assertTrue(metadata["library_name"] == "transformers")
docs/transformers/tests/utils/test_model_debugging_utils.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2025 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import gc
17
+ import json
18
+ import os
19
+ import tempfile
20
+ import unittest
21
+ from pathlib import Path
22
+
23
+ from transformers import is_torch_available
24
+ from transformers.model_debugging_utils import model_addition_debugger_context
25
+
26
+
27
+ if is_torch_available():
28
+ import torch
29
+ from torch import nn
30
+
31
+ class ToyModel(nn.Module):
32
+ def __init__(self):
33
+ super().__init__()
34
+ self.embed = nn.Embedding(10, 4)
35
+ self.linear_1 = nn.Linear(4, 8)
36
+ self.linear_2 = nn.Linear(8, 2)
37
+ self.act = nn.ReLU()
38
+
39
+ def forward(self, input_ids: str):
40
+ hidden_states = self.embed(input_ids).mean(dim=1)
41
+ hidden_states = self.act(self.linear_1(hidden_states))
42
+ return self.linear_2(hidden_states)
43
+
44
+ class TestModelAdditionDebugger(unittest.TestCase):
45
+ def setUp(self):
46
+ self.model = ToyModel()
47
+ self.inputs = {"input_ids": torch.randint(0, 10, (1, 3))}
48
+
49
+ def tearDown(self):
50
+ gc.collect()
51
+
52
+ def test_debugger_outputs(self):
53
+ with tempfile.TemporaryDirectory() as tmpdir:
54
+ with model_addition_debugger_context(self.model, debug_path=str(tmpdir)):
55
+ _ = self.model.forward(**self.inputs)
56
+
57
+ base = f"{self.model.__class__.__name__}_debug_tree"
58
+ summary = Path(os.path.join(tmpdir, f"{base}_SUMMARY.json"))
59
+ full = Path(os.path.join(tmpdir, f"{base}_FULL_TENSORS.json"))
60
+ self.assertTrue(os.path.isfile(summary) and os.path.isfile(full))
61
+ data = json.loads(summary.read_text())
62
+ self.assertTrue({"module_path", "inputs", "children"} <= data.keys())
63
+ self.assertTrue(data["children"])
64
+
65
+ class ToyLayer(nn.Module):
66
+ def __init__(self, layer_index):
67
+ super().__init__()
68
+ self.layer_index = layer_index
69
+ self.layer_operation = nn.Linear(4, 4)
70
+
71
+ def forward(self, hidden_states):
72
+ return self.layer_operation(hidden_states)
73
+
74
+ class ToyModelWithLayers(nn.Module):
75
+ def __init__(self):
76
+ super().__init__()
77
+ self.input_proj = nn.Linear(4, 4)
78
+ self.layers = nn.ModuleList([ToyLayer(layer_index) for layer_index in range(6)])
79
+ self.output_proj = nn.Linear(4, 2)
80
+
81
+ def forward(self, x):
82
+ x = self.input_proj(x)
83
+ for layer in self.layers:
84
+ x = layer(x)
85
+ return self.output_proj(x)
86
+
87
+ class TestModelWithLayers(unittest.TestCase):
88
+ def setUp(self):
89
+ self.inputs = {"input_ids": torch.randint(0, 10, (1, 3))}
90
+ self.model_with_layers = ToyModelWithLayers()
91
+ self.dense_input = {"x": torch.randn(1, 4)}
92
+
93
+ def tearDown(self):
94
+ gc.collect()
95
+
96
+ def test_layer_pruning_behavior(self):
97
+ # No pruning: expect all 6 layers
98
+ with tempfile.TemporaryDirectory() as tmpdir:
99
+ with model_addition_debugger_context(self.model_with_layers, debug_path=tmpdir, do_prune_layers=False):
100
+ _ = self.model_with_layers(**self.dense_input)
101
+
102
+ summary_path = os.path.join(tmpdir, "ToyModelWithLayers_debug_tree_SUMMARY.json")
103
+ with open(summary_path) as f:
104
+ data = json.load(f)
105
+ self.assertEqual(set(data.keys()), {"module_path", "inputs", "children"})
106
+ for layer_index in range(6):
107
+ self.assertEqual(
108
+ data["children"][layer_index + 1]["module_path"],
109
+ f"ToyModelWithLayers.layers.{int(layer_index)}",
110
+ )
111
+
112
+ # Pruning: expect only 2 layers (0 and 5)
113
+ with tempfile.TemporaryDirectory() as tmpdir:
114
+ with model_addition_debugger_context(self.model_with_layers, debug_path=tmpdir, do_prune_layers=True):
115
+ _ = self.model_with_layers(**self.dense_input)
116
+
117
+ summary_path = os.path.join(tmpdir, "ToyModelWithLayers_debug_tree_SUMMARY.json")
118
+ with open(summary_path) as f:
119
+ data = json.load(f)
120
+ self.assertEqual(set(data.keys()), {"module_path", "inputs", "children"})
121
+ self.assertEqual(data["children"][1]["module_path"], "ToyModelWithLayers.layers.0")
122
+ self.assertEqual(data["children"][2]["module_path"], "ToyModelWithLayers.layers.5")
docs/transformers/tests/utils/test_model_output.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 The Hugging Face Team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import io
16
+ import unittest
17
+ from dataclasses import dataclass
18
+ from typing import Optional
19
+
20
+ from transformers import AlbertForMaskedLM
21
+ from transformers.testing_utils import require_torch
22
+ from transformers.utils import ModelOutput, is_torch_available
23
+
24
+
25
+ if is_torch_available():
26
+ import torch
27
+
28
+ from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_2
29
+
30
+
31
+ @dataclass
32
+ class ModelOutputTest(ModelOutput):
33
+ a: float
34
+ b: Optional[float] = None
35
+ c: Optional[float] = None
36
+
37
+
38
+ class ModelOutputTester(unittest.TestCase):
39
+ def test_get_attributes(self):
40
+ x = ModelOutputTest(a=30)
41
+ self.assertEqual(x.a, 30)
42
+ self.assertIsNone(x.b)
43
+ self.assertIsNone(x.c)
44
+ with self.assertRaises(AttributeError):
45
+ _ = x.d
46
+
47
+ def test_index_with_ints_and_slices(self):
48
+ x = ModelOutputTest(a=30, b=10)
49
+ self.assertEqual(x[0], 30)
50
+ self.assertEqual(x[1], 10)
51
+ self.assertEqual(x[:2], (30, 10))
52
+ self.assertEqual(x[:], (30, 10))
53
+
54
+ x = ModelOutputTest(a=30, c=10)
55
+ self.assertEqual(x[0], 30)
56
+ self.assertEqual(x[1], 10)
57
+ self.assertEqual(x[:2], (30, 10))
58
+ self.assertEqual(x[:], (30, 10))
59
+
60
+ def test_index_with_strings(self):
61
+ x = ModelOutputTest(a=30, b=10)
62
+ self.assertEqual(x["a"], 30)
63
+ self.assertEqual(x["b"], 10)
64
+ with self.assertRaises(KeyError):
65
+ _ = x["c"]
66
+
67
+ x = ModelOutputTest(a=30, c=10)
68
+ self.assertEqual(x["a"], 30)
69
+ self.assertEqual(x["c"], 10)
70
+ with self.assertRaises(KeyError):
71
+ _ = x["b"]
72
+
73
+ def test_dict_like_properties(self):
74
+ x = ModelOutputTest(a=30)
75
+ self.assertEqual(list(x.keys()), ["a"])
76
+ self.assertEqual(list(x.values()), [30])
77
+ self.assertEqual(list(x.items()), [("a", 30)])
78
+ self.assertEqual(list(x), ["a"])
79
+
80
+ x = ModelOutputTest(a=30, b=10)
81
+ self.assertEqual(list(x.keys()), ["a", "b"])
82
+ self.assertEqual(list(x.values()), [30, 10])
83
+ self.assertEqual(list(x.items()), [("a", 30), ("b", 10)])
84
+ self.assertEqual(list(x), ["a", "b"])
85
+
86
+ x = ModelOutputTest(a=30, c=10)
87
+ self.assertEqual(list(x.keys()), ["a", "c"])
88
+ self.assertEqual(list(x.values()), [30, 10])
89
+ self.assertEqual(list(x.items()), [("a", 30), ("c", 10)])
90
+ self.assertEqual(list(x), ["a", "c"])
91
+
92
+ with self.assertRaises(Exception):
93
+ x = x.update({"d": 20})
94
+ with self.assertRaises(Exception):
95
+ del x["a"]
96
+ with self.assertRaises(Exception):
97
+ _ = x.pop("a")
98
+ with self.assertRaises(Exception):
99
+ _ = x.setdefault("d", 32)
100
+
101
+ def test_set_attributes(self):
102
+ x = ModelOutputTest(a=30)
103
+ x.a = 10
104
+ self.assertEqual(x.a, 10)
105
+ self.assertEqual(x["a"], 10)
106
+
107
+ def test_set_keys(self):
108
+ x = ModelOutputTest(a=30)
109
+ x["a"] = 10
110
+ self.assertEqual(x.a, 10)
111
+ self.assertEqual(x["a"], 10)
112
+
113
+ def test_instantiate_from_dict(self):
114
+ x = ModelOutputTest({"a": 30, "b": 10})
115
+ self.assertEqual(list(x.keys()), ["a", "b"])
116
+ self.assertEqual(x.a, 30)
117
+ self.assertEqual(x.b, 10)
118
+
119
+ def test_instantiate_from_iterator(self):
120
+ x = ModelOutputTest([("a", 30), ("b", 10)])
121
+ self.assertEqual(list(x.keys()), ["a", "b"])
122
+ self.assertEqual(x.a, 30)
123
+ self.assertEqual(x.b, 10)
124
+
125
+ with self.assertRaises(ValueError):
126
+ _ = ModelOutputTest([("a", 30), (10, 10)])
127
+
128
+ x = ModelOutputTest(a=(30, 30))
129
+ self.assertEqual(list(x.keys()), ["a"])
130
+ self.assertEqual(x.a, (30, 30))
131
+
132
+ @require_torch
133
+ def test_torch_pytree(self):
134
+ # ensure torch.utils._pytree treats ModelOutput subclasses as nodes (and not leaves)
135
+ # this is important for DistributedDataParallel gradient synchronization with static_graph=True
136
+ import torch.utils._pytree as pytree
137
+
138
+ x = ModelOutput({"a": 1.0, "c": 2.0})
139
+ self.assertFalse(pytree._is_leaf(x))
140
+
141
+ x = ModelOutputTest(a=1.0, c=2.0)
142
+ self.assertFalse(pytree._is_leaf(x))
143
+
144
+ expected_flat_outs = [1.0, 2.0]
145
+ expected_tree_spec = pytree.TreeSpec(ModelOutputTest, ["a", "c"], [pytree.LeafSpec(), pytree.LeafSpec()])
146
+
147
+ actual_flat_outs, actual_tree_spec = pytree.tree_flatten(x)
148
+ self.assertEqual(expected_flat_outs, actual_flat_outs)
149
+ self.assertEqual(expected_tree_spec, actual_tree_spec)
150
+
151
+ unflattened_x = pytree.tree_unflatten(actual_flat_outs, actual_tree_spec)
152
+ self.assertEqual(x, unflattened_x)
153
+
154
+ if is_torch_greater_or_equal_than_2_2:
155
+ self.assertEqual(
156
+ pytree.treespec_dumps(actual_tree_spec),
157
+ '[1, {"type": "tests.utils.test_model_output.ModelOutputTest", "context": "[\\"a\\", \\"c\\"]", "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}]',
158
+ )
159
+
160
+ # TODO: @ydshieh
161
+ @unittest.skip(reason="CPU OOM")
162
+ @require_torch
163
+ def test_export_serialization(self):
164
+ if not is_torch_greater_or_equal_than_2_2:
165
+ self.skipTest(reason="Export serialization requires torch >= 2.2.0")
166
+
167
+ model_cls = AlbertForMaskedLM
168
+ model_config = model_cls.config_class()
169
+ model = model_cls(model_config)
170
+
171
+ input_dict = {"input_ids": torch.randint(0, 30000, (1, 512), dtype=torch.int64, requires_grad=False)}
172
+
173
+ ep = torch.export.export(model, (), input_dict)
174
+
175
+ buffer = io.BytesIO()
176
+ torch.export.save(ep, buffer)
177
+ buffer.seek(0)
178
+ loaded_ep = torch.export.load(buffer)
179
+
180
+ input_dict = {"input_ids": torch.randint(0, 30000, (1, 512), dtype=torch.int64, requires_grad=False)}
181
+ assert torch.allclose(model(**input_dict).logits, loaded_ep(**input_dict).logits)
182
+
183
+
184
+ class ModelOutputTestNoDataclass(ModelOutput):
185
+ """Invalid test subclass of ModelOutput where @dataclass decorator is not used"""
186
+
187
+ a: float
188
+ b: Optional[float] = None
189
+ c: Optional[float] = None
190
+
191
+
192
+ class ModelOutputSubclassTester(unittest.TestCase):
193
+ def test_direct_model_output(self):
194
+ # Check that direct usage of ModelOutput instantiates without errors
195
+ ModelOutput({"a": 1.1})
196
+
197
+ def test_subclass_no_dataclass(self):
198
+ # Check that a subclass of ModelOutput without @dataclass is invalid
199
+ # A valid subclass is inherently tested other unit tests above.
200
+ with self.assertRaises(TypeError):
201
+ ModelOutputTestNoDataclass(a=1.1, b=2.2, c=3.3)
docs/transformers/tests/utils/test_modeling_flax_utils.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import tempfile
16
+ import unittest
17
+
18
+ import numpy as np
19
+ from huggingface_hub import HfFolder, snapshot_download
20
+
21
+ from transformers import BertConfig, is_flax_available
22
+ from transformers.testing_utils import (
23
+ TOKEN,
24
+ CaptureLogger,
25
+ TemporaryHubRepo,
26
+ is_staging_test,
27
+ require_flax,
28
+ require_safetensors,
29
+ )
30
+ from transformers.utils import FLAX_WEIGHTS_NAME, SAFE_WEIGHTS_NAME, logging
31
+
32
+
33
+ if is_flax_available():
34
+ import os
35
+
36
+ from flax.core.frozen_dict import unfreeze
37
+ from flax.traverse_util import flatten_dict
38
+
39
+ from transformers import FlaxBertModel
40
+
41
+ os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8
42
+
43
+
44
+ @require_flax
45
+ @is_staging_test
46
+ class FlaxModelPushToHubTester(unittest.TestCase):
47
+ @classmethod
48
+ def setUpClass(cls):
49
+ cls._token = TOKEN
50
+ HfFolder.save_token(TOKEN)
51
+
52
+ def test_push_to_hub(self):
53
+ with TemporaryHubRepo(token=self._token) as tmp_repo:
54
+ config = BertConfig(
55
+ vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
56
+ )
57
+ model = FlaxBertModel(config)
58
+ model.push_to_hub(tmp_repo.repo_id, token=self._token)
59
+
60
+ new_model = FlaxBertModel.from_pretrained(tmp_repo.repo_id)
61
+
62
+ base_params = flatten_dict(unfreeze(model.params))
63
+ new_params = flatten_dict(unfreeze(new_model.params))
64
+
65
+ for key in base_params.keys():
66
+ max_diff = (base_params[key] - new_params[key]).sum().item()
67
+ self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
68
+
69
+ def test_push_to_hub_via_save_pretrained(self):
70
+ with TemporaryHubRepo(token=self._token) as tmp_repo:
71
+ config = BertConfig(
72
+ vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
73
+ )
74
+ model = FlaxBertModel(config)
75
+ # Push to hub via save_pretrained
76
+ with tempfile.TemporaryDirectory() as tmp_dir:
77
+ model.save_pretrained(tmp_dir, repo_id=tmp_repo.repo_id, push_to_hub=True, token=self._token)
78
+
79
+ new_model = FlaxBertModel.from_pretrained(tmp_repo.repo_id)
80
+
81
+ base_params = flatten_dict(unfreeze(model.params))
82
+ new_params = flatten_dict(unfreeze(new_model.params))
83
+
84
+ for key in base_params.keys():
85
+ max_diff = (base_params[key] - new_params[key]).sum().item()
86
+ self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
87
+
88
+ def test_push_to_hub_in_organization(self):
89
+ with TemporaryHubRepo(namespace="valid_org", token=self._token) as tmp_repo:
90
+ config = BertConfig(
91
+ vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
92
+ )
93
+ model = FlaxBertModel(config)
94
+ model.push_to_hub(tmp_repo.repo_id, token=self._token)
95
+
96
+ new_model = FlaxBertModel.from_pretrained(tmp_repo.repo_id)
97
+
98
+ base_params = flatten_dict(unfreeze(model.params))
99
+ new_params = flatten_dict(unfreeze(new_model.params))
100
+
101
+ for key in base_params.keys():
102
+ max_diff = (base_params[key] - new_params[key]).sum().item()
103
+ self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
104
+
105
+ def test_push_to_hub_in_organization_via_save_pretrained(self):
106
+ with TemporaryHubRepo(namespace="valid_org", token=self._token) as tmp_repo:
107
+ config = BertConfig(
108
+ vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
109
+ )
110
+ model = FlaxBertModel(config)
111
+ # Push to hub via save_pretrained
112
+ with tempfile.TemporaryDirectory() as tmp_dir:
113
+ model.save_pretrained(tmp_dir, repo_id=tmp_repo.repo_id, push_to_hub=True, token=self._token)
114
+
115
+ new_model = FlaxBertModel.from_pretrained(tmp_repo.repo_id)
116
+
117
+ base_params = flatten_dict(unfreeze(model.params))
118
+ new_params = flatten_dict(unfreeze(new_model.params))
119
+
120
+ for key in base_params.keys():
121
+ max_diff = (base_params[key] - new_params[key]).sum().item()
122
+ self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
123
+
124
+
125
+ def check_models_equal(model1, model2):
126
+ models_are_equal = True
127
+ flat_params_1 = flatten_dict(model1.params)
128
+ flat_params_2 = flatten_dict(model2.params)
129
+ for key in flat_params_1.keys():
130
+ if np.sum(np.abs(flat_params_1[key] - flat_params_2[key])) > 1e-4:
131
+ models_are_equal = False
132
+
133
+ return models_are_equal
134
+
135
+
136
+ @require_flax
137
+ class FlaxModelUtilsTest(unittest.TestCase):
138
+ def test_model_from_pretrained_subfolder(self):
139
+ config = BertConfig.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
140
+ model = FlaxBertModel(config)
141
+
142
+ subfolder = "bert"
143
+ with tempfile.TemporaryDirectory() as tmp_dir:
144
+ model.save_pretrained(os.path.join(tmp_dir, subfolder))
145
+
146
+ with self.assertRaises(OSError):
147
+ _ = FlaxBertModel.from_pretrained(tmp_dir)
148
+
149
+ model_loaded = FlaxBertModel.from_pretrained(tmp_dir, subfolder=subfolder)
150
+
151
+ self.assertTrue(check_models_equal(model, model_loaded))
152
+
153
+ def test_model_from_pretrained_subfolder_sharded(self):
154
+ config = BertConfig.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
155
+ model = FlaxBertModel(config)
156
+
157
+ subfolder = "bert"
158
+ with tempfile.TemporaryDirectory() as tmp_dir:
159
+ model.save_pretrained(os.path.join(tmp_dir, subfolder), max_shard_size="10KB")
160
+
161
+ with self.assertRaises(OSError):
162
+ _ = FlaxBertModel.from_pretrained(tmp_dir)
163
+
164
+ model_loaded = FlaxBertModel.from_pretrained(tmp_dir, subfolder=subfolder)
165
+
166
+ self.assertTrue(check_models_equal(model, model_loaded))
167
+
168
+ def test_model_from_pretrained_hub_subfolder(self):
169
+ subfolder = "bert"
170
+ model_id = "hf-internal-testing/tiny-random-bert-subfolder"
171
+
172
+ with self.assertRaises(OSError):
173
+ _ = FlaxBertModel.from_pretrained(model_id)
174
+
175
+ model = FlaxBertModel.from_pretrained(model_id, subfolder=subfolder)
176
+
177
+ self.assertIsNotNone(model)
178
+
179
+ def test_model_from_pretrained_hub_subfolder_sharded(self):
180
+ subfolder = "bert"
181
+ model_id = "hf-internal-testing/tiny-random-bert-sharded-subfolder"
182
+ with self.assertRaises(OSError):
183
+ _ = FlaxBertModel.from_pretrained(model_id)
184
+
185
+ model = FlaxBertModel.from_pretrained(model_id, subfolder=subfolder)
186
+
187
+ self.assertIsNotNone(model)
188
+
189
+ @require_safetensors
190
+ def test_safetensors_save_and_load(self):
191
+ model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
192
+ with tempfile.TemporaryDirectory() as tmp_dir:
193
+ model.save_pretrained(tmp_dir, safe_serialization=True)
194
+
195
+ # No msgpack file, only a model.safetensors
196
+ self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
197
+ self.assertFalse(os.path.isfile(os.path.join(tmp_dir, FLAX_WEIGHTS_NAME)))
198
+
199
+ new_model = FlaxBertModel.from_pretrained(tmp_dir)
200
+
201
+ self.assertTrue(check_models_equal(model, new_model))
202
+
203
+ @require_safetensors
204
+ def test_safetensors_load_from_hub(self):
205
+ """
206
+ This test checks that we can load safetensors from a checkpoint that only has those on the Hub
207
+ """
208
+ flax_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
209
+
210
+ # Can load from the Flax-formatted checkpoint
211
+ safetensors_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-safetensors-only")
212
+ self.assertTrue(check_models_equal(flax_model, safetensors_model))
213
+
214
+ @require_safetensors
215
+ def test_safetensors_load_from_local(self):
216
+ """
217
+ This test checks that we can load safetensors from a checkpoint that only has those on the Hub
218
+ """
219
+ with tempfile.TemporaryDirectory() as tmp:
220
+ location = snapshot_download("hf-internal-testing/tiny-bert-flax-only", cache_dir=tmp)
221
+ flax_model = FlaxBertModel.from_pretrained(location)
222
+
223
+ with tempfile.TemporaryDirectory() as tmp:
224
+ location = snapshot_download("hf-internal-testing/tiny-bert-flax-safetensors-only", cache_dir=tmp)
225
+ safetensors_model = FlaxBertModel.from_pretrained(location)
226
+
227
+ self.assertTrue(check_models_equal(flax_model, safetensors_model))
228
+
229
+ @require_safetensors
230
+ def test_safetensors_load_from_hub_msgpack_before_safetensors(self):
231
+ """
232
+ This test checks that we'll first download msgpack weights before safetensors
233
+ The safetensors file on that repo is a pt safetensors and therefore cannot be loaded without PyTorch
234
+ """
235
+ FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors-msgpack")
236
+
237
+ @require_safetensors
238
+ def test_safetensors_load_from_local_msgpack_before_safetensors(self):
239
+ """
240
+ This test checks that we'll first download msgpack weights before safetensors
241
+ The safetensors file on that repo is a pt safetensors and therefore cannot be loaded without PyTorch
242
+ """
243
+ with tempfile.TemporaryDirectory() as tmp:
244
+ location = snapshot_download("hf-internal-testing/tiny-bert-pt-safetensors-msgpack", cache_dir=tmp)
245
+ FlaxBertModel.from_pretrained(location)
246
+
247
+ @require_safetensors
248
+ def test_safetensors_flax_from_flax(self):
249
+ model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
250
+
251
+ with tempfile.TemporaryDirectory() as tmp_dir:
252
+ model.save_pretrained(tmp_dir, safe_serialization=True)
253
+ new_model = FlaxBertModel.from_pretrained(tmp_dir)
254
+
255
+ self.assertTrue(check_models_equal(model, new_model))
256
+
257
+ @require_safetensors
258
+ def test_safetensors_flax_from_sharded_msgpack_with_sharded_safetensors_local(self):
259
+ with tempfile.TemporaryDirectory() as tmp_dir:
260
+ path = snapshot_download(
261
+ "hf-internal-testing/tiny-bert-flax-safetensors-msgpack-sharded", cache_dir=tmp_dir
262
+ )
263
+
264
+ # This should not raise even if there are two types of sharded weights
265
+ FlaxBertModel.from_pretrained(path)
266
+
267
+ @require_safetensors
268
+ def test_safetensors_flax_from_sharded_msgpack_with_sharded_safetensors_hub(self):
269
+ # This should not raise even if there are two types of sharded weights
270
+ # This should discard the safetensors weights in favor of the msgpack sharded weights
271
+ FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-safetensors-msgpack-sharded")
272
+
273
+ @require_safetensors
274
+ def test_safetensors_from_pt_bf16(self):
275
+ # This should not raise; should be able to load bf16-serialized torch safetensors without issue
276
+ # and without torch.
277
+ logger = logging.get_logger("transformers.modeling_flax_utils")
278
+
279
+ with CaptureLogger(logger) as cl:
280
+ FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors-bf16")
281
+
282
+ self.assertTrue(
283
+ "Some of the weights of FlaxBertModel were initialized in bfloat16 precision from the model checkpoint"
284
+ in cl.out
285
+ )
docs/transformers/tests/utils/test_modeling_rope_utils.py ADDED
@@ -0,0 +1,453 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 HuggingFace Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import math
17
+ import unittest
18
+
19
+ from transformers import LlamaConfig
20
+ from transformers.testing_utils import is_torch_available, require_torch, torch_device
21
+
22
+
23
+ if is_torch_available():
24
+ import torch
25
+
26
+ from transformers import ROPE_INIT_FUNCTIONS
27
+ from transformers.modeling_rope_utils import rope_config_validation
28
+
29
+
30
+ @require_torch
31
+ class RopeTest(unittest.TestCase):
32
+ def test_rope_validation(self):
33
+ config = LlamaConfig()
34
+ all_rope_types = ROPE_INIT_FUNCTIONS.keys()
35
+
36
+ # The base config is always valid (default RoPE)
37
+ rope_config_validation(config)
38
+
39
+ # If we explicitly set the other RoPE types, then validation should fail
40
+ for rope_type in all_rope_types:
41
+ if rope_type != "default":
42
+ config.rope_scaling = {"rope_type": rope_type}
43
+ with self.assertRaises(KeyError):
44
+ rope_config_validation(config)
45
+
46
+ # Parameters are exclusive to their own RoPE type, and should raise an exception if incorrectly passed
47
+ valid_param_mapping = {
48
+ "factor": ["linear", "dynamic", "yarn", "longrope"],
49
+ "attention_factor": ["yarn", "longrope"],
50
+ "beta_fast": ["yarn"],
51
+ "beta_slow": ["yarn"],
52
+ "short_factor": ["longrope"],
53
+ "long_factor": ["longrope"],
54
+ }
55
+ for rope_type in all_rope_types:
56
+ if rope_type == "default":
57
+ continue # checked above
58
+ for param, valid_rope_types in valid_param_mapping.items():
59
+ # Set `param` with a dummy value -- we want to test the dict key
60
+ config.rope_scaling = {"rope_type": rope_type, param: True}
61
+ if rope_type in valid_rope_types:
62
+ continue
63
+ else:
64
+ with self.assertRaises(KeyError):
65
+ rope_config_validation(config)
66
+
67
+ # Any other parameters passed to RoPE will raise a warning that a particular key is not used
68
+ # But sometimes we can have model-specific RoPE kwargs and bypass warning with `ignore_keys`
69
+ model_specific_kwarg = "mrope_sections" # e,g in Qwen2-VL
70
+
71
+ for rope_type in all_rope_types:
72
+ if rope_type == "default":
73
+ config.rope_scaling = {"rope_type": rope_type, model_specific_kwarg: True}
74
+ rope_config_validation(config, ignore_keys={model_specific_kwarg})
75
+ with self.assertLogs("transformers.modeling_rope_utils", level="WARNING") as logs:
76
+ rope_config_validation(config)
77
+ self.assertEqual(len(logs.output), 1)
78
+ self.assertIn(model_specific_kwarg, logs.output[0])
79
+
80
+ def test_default_rope_function_bc(self):
81
+ config = LlamaConfig()
82
+ device = torch_device
83
+
84
+ rope_kwargs = {
85
+ "rope_type": "default",
86
+ "dim": config.hidden_size // config.num_attention_heads,
87
+ "max_position_embeddings": config.max_position_embeddings,
88
+ "base": config.rope_theta,
89
+ }
90
+
91
+ rope_fn = ROPE_INIT_FUNCTIONS["default"]
92
+ config_freqs = rope_fn(config=config, device=device)[0]
93
+ kwargs_freqs = rope_fn(**rope_kwargs, device=device)[0]
94
+ torch.testing.assert_close(config_freqs, kwargs_freqs)
95
+
96
+ def test_linear_rope_function_bc(self):
97
+ config = LlamaConfig()
98
+ config.rope_scaling = {"rope_type": "linear", "factor": 10.0}
99
+ device = torch_device
100
+
101
+ rope_kwargs = {
102
+ "rope_type": "linear",
103
+ "dim": config.hidden_size // config.num_attention_heads,
104
+ "max_position_embeddings": config.max_position_embeddings,
105
+ "base": config.rope_theta,
106
+ "factor": 10.0,
107
+ }
108
+
109
+ rope_fn = ROPE_INIT_FUNCTIONS["linear"]
110
+ config_freqs = rope_fn(config=config, device=device)[0]
111
+ kwargs_freqs = rope_fn(**rope_kwargs, device=device)[0]
112
+ torch.testing.assert_close(config_freqs, kwargs_freqs)
113
+
114
+ def test_dynamic_rope_function_bc(self):
115
+ config = LlamaConfig()
116
+ config.rope_scaling = {"rope_type": "dynamic", "factor": 10.0}
117
+ device = torch_device
118
+
119
+ rope_kwargs = {
120
+ "rope_type": "dynamic",
121
+ "dim": config.hidden_size // config.num_attention_heads,
122
+ "max_position_embeddings": config.max_position_embeddings,
123
+ "base": config.rope_theta,
124
+ "factor": 10.0,
125
+ }
126
+
127
+ rope_fn = ROPE_INIT_FUNCTIONS["dynamic"]
128
+ config_freqs = rope_fn(config=config, device=device)[0]
129
+ kwargs_freqs = rope_fn(**rope_kwargs, device=device)[0]
130
+ torch.testing.assert_close(config_freqs, kwargs_freqs)
131
+
132
+ def test_default_rope_numerically(self):
133
+ # Note: some RoPE scaling methods start off by calling the default RoPE frequencies. If this test fails, then
134
+ # multiple RoPE strategies will fail.
135
+ # fmt: off
136
+ EXPECTED_INV_FREQ = torch.tensor(
137
+ [
138
+ 1.0000e+00, 8.6596e-01, 7.4989e-01, 6.4938e-01, 5.6234e-01, 4.8697e-01,
139
+ 4.2170e-01, 3.6517e-01, 3.1623e-01, 2.7384e-01, 2.3714e-01, 2.0535e-01,
140
+ 1.7783e-01, 1.5399e-01, 1.3335e-01, 1.1548e-01, 1.0000e-01, 8.6596e-02,
141
+ 7.4989e-02, 6.4938e-02, 5.6234e-02, 4.8697e-02, 4.2170e-02, 3.6517e-02,
142
+ 3.1623e-02, 2.7384e-02, 2.3714e-02, 2.0535e-02, 1.7783e-02, 1.5399e-02,
143
+ 1.3335e-02, 1.1548e-02, 1.0000e-02, 8.6596e-03, 7.4989e-03, 6.4938e-03,
144
+ 5.6234e-03, 4.8697e-03, 4.2170e-03, 3.6517e-03, 3.1623e-03, 2.7384e-03,
145
+ 2.3714e-03, 2.0535e-03, 1.7783e-03, 1.5399e-03, 1.3335e-03, 1.1548e-03,
146
+ 1.0000e-03, 8.6596e-04, 7.4989e-04, 6.4938e-04, 5.6234e-04, 4.8697e-04,
147
+ 4.2170e-04, 3.6517e-04, 3.1623e-04, 2.7384e-04, 2.3714e-04, 2.0535e-04,
148
+ 1.7783e-04, 1.5399e-04, 1.3335e-04, 1.1548e-04
149
+ ], device=torch_device
150
+ )
151
+ # fmt: on
152
+
153
+ # input sanity checks: if these change, the output will also change
154
+ config = LlamaConfig()
155
+ self.assertEqual(config.rope_scaling, None)
156
+ self.assertEqual(config.hidden_size, 4096)
157
+ self.assertEqual(config.num_attention_heads, 32)
158
+ self.assertEqual(config.rope_theta, 10000.0)
159
+ self.assertFalse(hasattr(config, "partial_rotary_factor"))
160
+
161
+ rope_fn = ROPE_INIT_FUNCTIONS["default"]
162
+ inv_freq, attention_scale = rope_fn(config=config, device=torch_device)
163
+
164
+ self.assertEqual(attention_scale, 1.0) # attention scale is always 1 for default RoPE
165
+ torch.testing.assert_close(inv_freq, EXPECTED_INV_FREQ)
166
+
167
+ def test_linear_rope_numerically(self):
168
+ # This is a linear scaling strategy, the **frequencies** are scaled linearly with respect to the default
169
+ # frequencies (= the inverse frequencies are scaled **inversely**)
170
+ config = LlamaConfig()
171
+ default_rope_fn = ROPE_INIT_FUNCTIONS["default"]
172
+ default_inv_freq, _ = default_rope_fn(config=config, device=torch_device)
173
+
174
+ rope_fn = ROPE_INIT_FUNCTIONS["linear"]
175
+ for factor in (2.0, 10.0, 20.0):
176
+ config.rope_scaling = {"rope_type": "linear", "factor": factor}
177
+ inv_freq, attention_scale = rope_fn(config=config, device=torch_device)
178
+ self.assertEqual(attention_scale, 1.0) # attention scale is always 1 for linear RoPE
179
+ torch.testing.assert_close(inv_freq, default_inv_freq / factor)
180
+
181
+ def test_dynamic_rope_numerically(self):
182
+ # fmt: off
183
+ EXPECTED_INV_FREQ = torch.tensor(
184
+ [
185
+ 1.0000e+00, 8.0931e-01, 6.5498e-01, 5.3008e-01, 4.2900e-01, 3.4720e-01,
186
+ 2.8099e-01, 2.2741e-01, 1.8404e-01, 1.4895e-01, 1.2055e-01, 9.7558e-02,
187
+ 7.8955e-02, 6.3899e-02, 5.1714e-02, 4.1853e-02, 3.3872e-02, 2.7413e-02,
188
+ 2.2185e-02, 1.7955e-02, 1.4531e-02, 1.1760e-02, 9.5176e-03, 7.7027e-03,
189
+ 6.2339e-03, 5.0451e-03, 4.0831e-03, 3.3045e-03, 2.6744e-03, 2.1644e-03,
190
+ 1.7517e-03, 1.4176e-03, 1.1473e-03, 9.2852e-04, 7.5146e-04, 6.0817e-04,
191
+ 4.9220e-04, 3.9834e-04, 3.2238e-04, 2.6091e-04, 2.1115e-04, 1.7089e-04,
192
+ 1.3830e-04, 1.1193e-04, 9.0585e-05, 7.3312e-05, 5.9332e-05, 4.8018e-05,
193
+ 3.8861e-05, 3.1451e-05, 2.5453e-05, 2.0600e-05, 1.6672e-05, 1.3492e-05,
194
+ 1.0920e-05, 8.8374e-06, 7.1522e-06, 5.7883e-06, 4.6845e-06, 3.7912e-06,
195
+ 3.0683e-06, 2.4832e-06, 2.0097e-06, 1.6265e-06
196
+ ], device=torch_device
197
+ )
198
+ # fmt: on
199
+
200
+ # input sanity checks: if these change, the output will also change
201
+ config = LlamaConfig()
202
+ self.assertEqual(config.rope_scaling, None)
203
+ self.assertEqual(config.hidden_size, 4096)
204
+ self.assertEqual(config.num_attention_heads, 32)
205
+ self.assertEqual(config.rope_theta, 10000.0)
206
+ self.assertFalse(hasattr(config, "partial_rotary_factor"))
207
+
208
+ rope_fn = ROPE_INIT_FUNCTIONS["default"]
209
+ default_inv_freq, _ = rope_fn(config=config, device=torch_device)
210
+
211
+ # Check 1: this is a dynamic scaling strategy, it will not scale unless we provide `seq_len` larger than the
212
+ # model's original training sequence length
213
+ rope_fn = ROPE_INIT_FUNCTIONS["dynamic"]
214
+ for factor in (2.0, 10.0, 20.0):
215
+ config.rope_scaling = {"rope_type": "dynamic", "factor": factor}
216
+ inv_freq, attention_scale = rope_fn(config=config, device=torch_device)
217
+ self.assertEqual(attention_scale, 1.0) # attention scale is always 1 for dynamic RoPE
218
+ torch.testing.assert_close(inv_freq, default_inv_freq)
219
+
220
+ inv_freq, _ = rope_fn(config=config, device=torch_device, seq_len=1)
221
+ torch.testing.assert_close(inv_freq, default_inv_freq)
222
+
223
+ # Check 2: if we provide `seq_len` larger than the model's original training sequence length, the frequencies
224
+ # will scale up (i.e., the inverse frequencies will scale down).
225
+ factor = 10.0
226
+ config.rope_scaling = {"rope_type": "dynamic", "factor": factor}
227
+ inv_freq, _ = rope_fn(config=config, device=torch_device, seq_len=16384)
228
+ with self.assertRaises(AssertionError): # It is NOT a linear factor
229
+ torch.testing.assert_close(inv_freq, default_inv_freq / factor)
230
+ torch.testing.assert_close(inv_freq, EXPECTED_INV_FREQ)
231
+
232
+ def test_yarn_rope_numerically(self):
233
+ # fmt: off
234
+ EXPECTED_INV_FREQ = torch.tensor(
235
+ [
236
+ 1.0000e+00, 8.6596e-01, 7.4989e-01, 6.4938e-01, 5.6234e-01, 4.8697e-01,
237
+ 4.2170e-01, 3.6517e-01, 3.1623e-01, 2.7384e-01, 2.3714e-01, 2.0535e-01,
238
+ 1.7783e-01, 1.5399e-01, 1.3335e-01, 1.1548e-01, 1.0000e-01, 8.3479e-02,
239
+ 6.9590e-02, 5.7925e-02, 4.8136e-02, 3.9931e-02, 3.3061e-02, 2.7315e-02,
240
+ 2.2515e-02, 1.8512e-02, 1.5177e-02, 1.2403e-02, 1.0101e-02, 8.1924e-03,
241
+ 6.6143e-03, 5.3120e-03, 4.2400e-03, 3.3599e-03, 2.6396e-03, 2.0520e-03,
242
+ 1.5746e-03, 1.1882e-03, 8.7713e-04, 6.2810e-04, 4.3007e-04, 2.7384e-04,
243
+ 2.3714e-04, 2.0535e-04, 1.7783e-04, 1.5399e-04, 1.3335e-04, 1.1548e-04,
244
+ 1.0000e-04, 8.6596e-05, 7.4989e-05, 6.4938e-05, 5.6234e-05, 4.8697e-05,
245
+ 4.2170e-05, 3.6517e-05, 3.1623e-05, 2.7384e-05, 2.3714e-05, 2.0535e-05,
246
+ 1.7783e-05, 1.5399e-05, 1.3335e-05, 1.1548e-05
247
+ ], device=torch_device
248
+ )
249
+ # fmt: on
250
+
251
+ # input sanity checks: if these change, the output will also change
252
+ config = LlamaConfig()
253
+ self.assertEqual(config.rope_scaling, None)
254
+ self.assertEqual(config.hidden_size, 4096)
255
+ self.assertEqual(config.num_attention_heads, 32)
256
+ self.assertEqual(config.rope_theta, 10000.0)
257
+ self.assertFalse(hasattr(config, "partial_rotary_factor"))
258
+
259
+ rope_fn = ROPE_INIT_FUNCTIONS["default"]
260
+ default_inv_freq, _ = rope_fn(config=config, device=torch_device)
261
+
262
+ # Check 1: according to the paper, if `attention_factor` is not specified, then it has a specific default --
263
+ # `0.1 * math.log(factor) + 1.0`
264
+ rope_fn = ROPE_INIT_FUNCTIONS["yarn"]
265
+ for factor in (2.0, 10.0, 20.0):
266
+ config.rope_scaling = {"rope_type": "yarn", "factor": factor}
267
+ _, attention_scale = rope_fn(config=config, device=torch_device)
268
+ self.assertEqual(attention_scale, 0.1 * math.log(factor) + 1.0)
269
+
270
+ config.rope_scaling = {"rope_type": "yarn", "factor": factor, "attention_factor": 0.5}
271
+ _, attention_scale = rope_fn(config=config, device=torch_device, seq_len=1)
272
+ self.assertEqual(attention_scale, 0.5)
273
+
274
+ # Check 2: based on `beta_fast` and `beta_slow`, the frequencies will be scaled between 1 and `factor`.
275
+ # Increasing `beta_fast` will make RoPE more interpolative (apply scaling), and the other way around.
276
+ # `beta_slow` behaves the opposite way. Remember: `beta_fast` > `beta_slow`
277
+ # (note: adds a margin to the test for numerical stability)
278
+ factor = 10.0
279
+ margin = 1e-8
280
+ config.rope_scaling = {"rope_type": "yarn", "factor": factor, "beta_fast": 32, "beta_slow": 1}
281
+ inv_freq, _ = rope_fn(config=config, device=torch_device)
282
+ is_bounded_by_factor = [
283
+ ((default_inv_freq[idx] / factor) - margin) <= yarn_inv_freq_value <= (default_inv_freq[idx] + margin)
284
+ for idx, yarn_inv_freq_value in enumerate(inv_freq)
285
+ ]
286
+ self.assertTrue(all(is_bounded_by_factor))
287
+
288
+ # super high beta_fast = interpolation (i.e. scaling) in all but the first inverse frequency. The last ~20
289
+ # values (empirically checked for `beta_fast` = 1000) should be very small to linear scaling
290
+ config.rope_scaling = {"rope_type": "yarn", "factor": factor, "beta_fast": 1000, "beta_slow": 1}
291
+ inv_freq, _ = rope_fn(config=config, device=torch_device)
292
+ is_interpolating = [
293
+ yarn_inv_freq_value < (default_inv_freq[idx] + margin) for idx, yarn_inv_freq_value in enumerate(inv_freq)
294
+ ]
295
+ self.assertFalse(is_interpolating[0])
296
+ self.assertTrue(all(is_interpolating[1:]))
297
+ torch.testing.assert_close(inv_freq[-20:], default_inv_freq[-20:] / factor)
298
+
299
+ # Check 3: numerical snapshot to avoid regressions
300
+ config.rope_scaling = {"rope_type": "yarn", "factor": factor, "beta_fast": 32, "beta_slow": 1}
301
+ inv_freq, _ = rope_fn(config=config, device=torch_device)
302
+ torch.testing.assert_close(inv_freq, EXPECTED_INV_FREQ)
303
+
304
+ def test_longrope_rope_numerically(self):
305
+ # input sanity checks: if these change, the output will also change
306
+ config = LlamaConfig()
307
+ self.assertEqual(config.rope_scaling, None)
308
+ self.assertEqual(config.hidden_size, 4096)
309
+ self.assertEqual(config.num_attention_heads, 32)
310
+ self.assertEqual(config.rope_theta, 10000.0)
311
+ self.assertFalse(hasattr(config, "partial_rotary_factor"))
312
+
313
+ # longrope applies scaling on EACH inv frequency, `short_factor` or `long_factor`, depending on the seq_len
314
+ dim = config.hidden_size // config.num_attention_heads
315
+ short_factor = [2.0] * (dim // 2) # scaling applied when seq_len <= max_position_embeddings
316
+ long_factor = torch.ones(dim // 2).cumsum(0).tolist() # scaling applied when seq_len > max_position_embeddings
317
+
318
+ rope_fn = ROPE_INIT_FUNCTIONS["default"]
319
+ default_inv_freq, _ = rope_fn(config=config, device=torch_device)
320
+
321
+ # Check 1: according to the paper, if `attention_factor` is not specified, then it has a specific default --
322
+ # `math.sqrt(1 + math.log(factor) / math.log(max_position_embeddings))`
323
+ rope_fn = ROPE_INIT_FUNCTIONS["longrope"]
324
+ max_position_embeddings = config.max_position_embeddings
325
+ for factor in (2.0, 10.0, 20.0):
326
+ config.rope_scaling = {
327
+ "rope_type": "longrope",
328
+ "factor": factor,
329
+ "short_factor": short_factor,
330
+ "long_factor": long_factor,
331
+ }
332
+ _, attention_scale = rope_fn(config=config, device=torch_device)
333
+ self.assertEqual(attention_scale, math.sqrt(1 + math.log(factor) / math.log(max_position_embeddings)))
334
+
335
+ config.rope_scaling = {
336
+ "rope_type": "longrope",
337
+ "factor": factor,
338
+ "short_factor": short_factor,
339
+ "long_factor": long_factor,
340
+ "attention_factor": 0.5,
341
+ }
342
+ _, attention_scale = rope_fn(config=config, device=torch_device, seq_len=1)
343
+ self.assertEqual(attention_scale, 0.5)
344
+
345
+ config.rope_scaling = {
346
+ "rope_type": "longrope",
347
+ "factor": factor,
348
+ "short_factor": short_factor,
349
+ "long_factor": long_factor,
350
+ }
351
+ self.assertEqual(config.rope_scaling.get("attention_factor"), None)
352
+ # Verify that "TypeError: '<' not supported between instances of 'NoneType' and 'int'" is not raised.
353
+ rope_config_validation(config)
354
+
355
+ # Check 2: seq_len == 0 -> short factor is applied to the default frequencies
356
+ config.rope_scaling = {
357
+ "rope_type": "longrope",
358
+ "factor": 1.0,
359
+ "short_factor": short_factor,
360
+ "long_factor": long_factor,
361
+ }
362
+ inv_freq, _ = rope_fn(config=config, device=torch_device, seq_len=0)
363
+ torch.testing.assert_close(inv_freq, default_inv_freq / torch.tensor(short_factor).to(torch_device))
364
+
365
+ # Check 3: seq_len > max_position_embeddings -> long factor is applied to the default frequencies
366
+ inv_freq, _ = rope_fn(config=config, device=torch_device, seq_len=config.max_position_embeddings + 1)
367
+ torch.testing.assert_close(inv_freq, default_inv_freq / torch.tensor(long_factor).to(torch_device))
368
+
369
+ def test_llama3_rope_numerically(self):
370
+ # fmt: off
371
+ EXPECTED_INV_FREQ = torch.tensor(
372
+ [
373
+ 1.0000e+00, 8.6596e-01, 7.4989e-01, 6.4938e-01, 5.6234e-01, 4.8697e-01,
374
+ 4.2170e-01, 3.6517e-01, 3.1623e-01, 2.7384e-01, 2.3714e-01, 2.0535e-01,
375
+ 1.7783e-01, 1.5399e-01, 1.3335e-01, 1.1548e-01, 1.0000e-01, 8.6596e-02,
376
+ 7.4989e-02, 6.4938e-02, 5.6234e-02, 4.8697e-02, 4.2170e-02, 3.6517e-02,
377
+ 3.1623e-02, 2.7384e-02, 2.3714e-02, 2.0535e-02, 1.7783e-02, 1.5399e-02,
378
+ 1.3335e-02, 1.0730e-02, 7.7785e-03, 5.6009e-03, 3.9991e-03, 2.8248e-03,
379
+ 1.9675e-03, 1.3449e-03, 8.9549e-04, 5.7363e-04, 3.4539e-04, 2.7384e-04,
380
+ 2.3714e-04, 2.0535e-04, 1.7783e-04, 1.5399e-04, 1.3335e-04, 1.1548e-04,
381
+ 1.0000e-04, 8.6596e-05, 7.4989e-05, 6.4938e-05, 5.6234e-05, 4.8697e-05,
382
+ 4.2170e-05, 3.6517e-05, 3.1623e-05, 2.7384e-05, 2.3714e-05, 2.0535e-05,
383
+ 1.7783e-05, 1.5399e-05, 1.3335e-05, 1.1548e-05
384
+ ], device=torch_device
385
+ )
386
+ # fmt: on
387
+
388
+ # input sanity checks: if these change, the output will also change
389
+ config = LlamaConfig()
390
+ self.assertEqual(config.rope_scaling, None)
391
+ self.assertEqual(config.hidden_size, 4096)
392
+ self.assertEqual(config.num_attention_heads, 32)
393
+ self.assertEqual(config.rope_theta, 10000.0)
394
+ self.assertFalse(hasattr(config, "partial_rotary_factor"))
395
+
396
+ rope_fn = ROPE_INIT_FUNCTIONS["default"]
397
+ default_inv_freq, _ = rope_fn(config=config, device=torch_device)
398
+
399
+ # Check 1: `attention_factor` is always 1
400
+ rope_fn = ROPE_INIT_FUNCTIONS["llama3"]
401
+ for factor in (2.0, 10.0, 20.0):
402
+ config.rope_scaling = {
403
+ "rope_type": "llama3",
404
+ "factor": factor,
405
+ "original_max_position_embeddings": 2048,
406
+ "low_freq_factor": 1,
407
+ "high_freq_factor": 4,
408
+ }
409
+ _, attention_scale = rope_fn(config=config, device=torch_device)
410
+ self.assertEqual(attention_scale, 1.0)
411
+
412
+ # Check 2: based on `low_freq_factor` and `high_freq_factor`, the frequencies will be scaled between 1 and
413
+ # `factor` (similar to yarn). Low frequencies get scaled by `factor`, high frequencies see no change, medium
414
+ # frequencies are scaled by a value in between. Changing `low_freq_factor` and `high_freq_factor` changes what
415
+ # is considered low, medium, and high frequencies.
416
+ factor = 10.0
417
+ config.rope_scaling = {
418
+ "rope_type": "llama3",
419
+ "factor": factor,
420
+ "original_max_position_embeddings": 2048,
421
+ "low_freq_factor": 1,
422
+ "high_freq_factor": 4,
423
+ }
424
+ inv_freq, _ = rope_fn(config=config, device=torch_device)
425
+ is_bounded_by_factor = [
426
+ (default_inv_freq[idx] / factor) <= llama3_inv_freq_value <= default_inv_freq[idx]
427
+ for idx, llama3_inv_freq_value in enumerate(inv_freq)
428
+ ]
429
+ self.assertTrue(all(is_bounded_by_factor))
430
+
431
+ # if we change `high_freq_factor` to a very high value, none is considered high-frequency -> ALL values will be
432
+ # scaled
433
+ config.rope_scaling = config.rope_scaling = {
434
+ "rope_type": "llama3",
435
+ "factor": factor,
436
+ "original_max_position_embeddings": 2048,
437
+ "low_freq_factor": 1,
438
+ "high_freq_factor": 1000,
439
+ }
440
+ inv_freq, _ = rope_fn(config=config, device=torch_device)
441
+ is_scaled = [yarn_inv_freq_value < default_inv_freq[idx] for idx, yarn_inv_freq_value in enumerate(inv_freq)]
442
+ self.assertTrue(all(is_scaled))
443
+
444
+ # Check 3: numerical snapshot to avoid regressions
445
+ config.rope_scaling = {
446
+ "rope_type": "llama3",
447
+ "factor": factor,
448
+ "original_max_position_embeddings": 2048,
449
+ "low_freq_factor": 1,
450
+ "high_freq_factor": 4,
451
+ }
452
+ inv_freq, _ = rope_fn(config=config, device=torch_device)
453
+ torch.testing.assert_close(inv_freq, EXPECTED_INV_FREQ)
docs/transformers/tests/utils/test_modeling_tf_core.py ADDED
@@ -0,0 +1,403 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2019 HuggingFace Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from __future__ import annotations
17
+
18
+ import copy
19
+ import os
20
+ import tempfile
21
+ from importlib import import_module
22
+ from math import isnan
23
+
24
+ from transformers import is_tf_available
25
+ from transformers.models.auto import get_values
26
+ from transformers.testing_utils import require_tf, slow
27
+
28
+ from ..test_modeling_tf_common import ids_tensor
29
+
30
+
31
+ if is_tf_available():
32
+ import numpy as np
33
+ import tensorflow as tf
34
+
35
+ from transformers import (
36
+ TF_MODEL_FOR_CAUSAL_LM_MAPPING,
37
+ TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
38
+ TF_MODEL_FOR_MASKED_LM_MAPPING,
39
+ TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
40
+ TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
41
+ TF_MODEL_FOR_PRETRAINING_MAPPING,
42
+ TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
43
+ TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
44
+ TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
45
+ TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
46
+ TFSharedEmbeddings,
47
+ )
48
+ from transformers.modeling_tf_utils import keras
49
+
50
+
51
+ @require_tf
52
+ class TFCoreModelTesterMixin:
53
+ model_tester = None
54
+ all_model_classes = ()
55
+ all_generative_model_classes = ()
56
+ test_mismatched_shapes = True
57
+ test_resize_embeddings = True
58
+ test_head_masking = True
59
+ is_encoder_decoder = False
60
+
61
+ def _prepare_for_class(self, inputs_dict, model_class, return_labels=False) -> dict:
62
+ inputs_dict = copy.deepcopy(inputs_dict)
63
+
64
+ if model_class in get_values(TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
65
+ inputs_dict = {
66
+ k: tf.tile(tf.expand_dims(v, 1), (1, self.model_tester.num_choices) + (1,) * (v.ndim - 1))
67
+ if isinstance(v, tf.Tensor) and v.ndim > 0
68
+ else v
69
+ for k, v in inputs_dict.items()
70
+ }
71
+
72
+ if return_labels:
73
+ if model_class in get_values(TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
74
+ inputs_dict["labels"] = tf.ones(self.model_tester.batch_size, dtype=tf.int32)
75
+ elif model_class in get_values(TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING):
76
+ inputs_dict["start_positions"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
77
+ inputs_dict["end_positions"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
78
+ elif model_class in [
79
+ *get_values(TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING),
80
+ *get_values(TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING),
81
+ ]:
82
+ inputs_dict["labels"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
83
+ elif model_class in get_values(TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING):
84
+ inputs_dict["next_sentence_label"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
85
+ elif model_class in [
86
+ *get_values(TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING),
87
+ *get_values(TF_MODEL_FOR_CAUSAL_LM_MAPPING),
88
+ *get_values(TF_MODEL_FOR_MASKED_LM_MAPPING),
89
+ *get_values(TF_MODEL_FOR_PRETRAINING_MAPPING),
90
+ *get_values(TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING),
91
+ ]:
92
+ inputs_dict["labels"] = tf.zeros(
93
+ (self.model_tester.batch_size, self.model_tester.seq_length), dtype=tf.int32
94
+ )
95
+ return inputs_dict
96
+
97
+ @slow
98
+ def test_graph_mode(self):
99
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
100
+ for model_class in self.all_model_classes[:2]:
101
+ inputs = self._prepare_for_class(inputs_dict, model_class)
102
+ model = model_class(config)
103
+
104
+ @tf.function
105
+ def run_in_graph_mode():
106
+ return model(inputs)
107
+
108
+ outputs = run_in_graph_mode()
109
+ self.assertIsNotNone(outputs)
110
+
111
+ @slow
112
+ def test_xla_mode(self):
113
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
114
+ for model_class in self.all_model_classes[:2]:
115
+ inputs = self._prepare_for_class(inputs_dict, model_class)
116
+ model = model_class(config)
117
+
118
+ @tf.function(experimental_compile=True)
119
+ def run_in_graph_mode():
120
+ return model(inputs)
121
+
122
+ outputs = run_in_graph_mode()
123
+ self.assertIsNotNone(outputs)
124
+
125
+ @slow
126
+ def test_xla_fit(self):
127
+ # This is a copy of the test_keras_fit method, but we use XLA compilation instead of eager
128
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
129
+ for model_class in self.all_model_classes[:2]:
130
+ model = model_class(config)
131
+ if getattr(model, "hf_compute_loss", None):
132
+ # Test that model correctly compute the loss with kwargs
133
+ prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
134
+ # Is there a better way to remove these decoder inputs?
135
+ prepared_for_class = {
136
+ key: val
137
+ for key, val in prepared_for_class.items()
138
+ if key not in ("head_mask", "decoder_head_mask", "cross_attn_head_mask", "decoder_input_ids")
139
+ }
140
+
141
+ possible_label_cols = {
142
+ "labels",
143
+ "label",
144
+ "label_ids",
145
+ "start_positions",
146
+ "start_position",
147
+ "end_positions",
148
+ "end_position",
149
+ "next_sentence_label",
150
+ }
151
+ label_names = possible_label_cols.intersection(set(prepared_for_class))
152
+ self.assertGreater(len(label_names), 0, msg="No matching label names found!")
153
+ labels = {key: val for key, val in prepared_for_class.items() if key in label_names}
154
+ inputs_minus_labels = {key: val for key, val in prepared_for_class.items() if key not in label_names}
155
+ self.assertGreater(len(inputs_minus_labels), 0)
156
+
157
+ # Make sure it works with XLA!
158
+ model.compile(optimizer=keras.optimizers.SGD(0.0), jit_compile=True)
159
+ # Make sure the model fits without crashing regardless of where we pass the labels
160
+ history = model.fit(
161
+ prepared_for_class,
162
+ validation_data=prepared_for_class,
163
+ steps_per_epoch=1,
164
+ validation_steps=1,
165
+ shuffle=False,
166
+ verbose=0,
167
+ )
168
+ loss = history.history["loss"][0]
169
+ self.assertTrue(not isnan(loss))
170
+ val_loss = history.history["val_loss"][0]
171
+ self.assertTrue(not isnan(val_loss))
172
+
173
+ # Now test it with separate labels, to make sure that path works in XLA too.
174
+ model = model_class(config)
175
+ model.compile(optimizer=keras.optimizers.SGD(0.0), jit_compile=True)
176
+ history = model.fit(
177
+ inputs_minus_labels,
178
+ labels,
179
+ validation_data=(inputs_minus_labels, labels),
180
+ steps_per_epoch=1,
181
+ validation_steps=1,
182
+ shuffle=False,
183
+ verbose=0,
184
+ )
185
+
186
+ loss = history.history["loss"][0]
187
+ self.assertTrue(not isnan(loss))
188
+ val_loss = history.history["val_loss"][0]
189
+ self.assertTrue(not isnan(val_loss))
190
+
191
+ @slow
192
+ def test_saved_model_creation_extended(self):
193
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
194
+ config.output_hidden_states = True
195
+ config.output_attentions = True
196
+
197
+ if hasattr(config, "use_cache"):
198
+ config.use_cache = True
199
+
200
+ encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", self.model_tester.seq_length)
201
+ encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
202
+
203
+ for model_class in self.all_model_classes[:2]:
204
+ class_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
205
+ model = model_class(config)
206
+ model.build_in_name_scope()
207
+ num_out = len(model(class_inputs_dict))
208
+
209
+ for key in list(class_inputs_dict.keys()):
210
+ # Remove keys not in the serving signature, as the SavedModel will not be compiled to deal with them
211
+ if key not in model.input_signature:
212
+ del class_inputs_dict[key]
213
+ # Check it's a tensor, in case the inputs dict has some bools in it too
214
+ elif isinstance(class_inputs_dict[key], tf.Tensor) and class_inputs_dict[key].dtype.is_integer:
215
+ class_inputs_dict[key] = tf.cast(class_inputs_dict[key], tf.int32)
216
+
217
+ if set(class_inputs_dict.keys()) != set(model.input_signature.keys()):
218
+ continue # Some models have inputs that the preparation functions don't create, we skip those
219
+
220
+ with tempfile.TemporaryDirectory() as tmpdirname:
221
+ model.save_pretrained(tmpdirname, saved_model=True)
222
+ saved_model_dir = os.path.join(tmpdirname, "saved_model", "1")
223
+ model = keras.models.load_model(saved_model_dir)
224
+ outputs = model(class_inputs_dict)
225
+
226
+ if self.is_encoder_decoder:
227
+ output_hidden_states = outputs["encoder_hidden_states"]
228
+ output_attentions = outputs["encoder_attentions"]
229
+ else:
230
+ output_hidden_states = outputs["hidden_states"]
231
+ output_attentions = outputs["attentions"]
232
+
233
+ self.assertEqual(len(outputs), num_out)
234
+
235
+ expected_num_layers = getattr(
236
+ self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
237
+ )
238
+
239
+ self.assertEqual(len(output_hidden_states), expected_num_layers)
240
+ self.assertListEqual(
241
+ list(output_hidden_states[0].shape[-2:]),
242
+ [self.model_tester.seq_length, self.model_tester.hidden_size],
243
+ )
244
+
245
+ self.assertEqual(len(output_attentions), self.model_tester.num_hidden_layers)
246
+ self.assertListEqual(
247
+ list(output_attentions[0].shape[-3:]),
248
+ [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
249
+ )
250
+
251
+ @slow
252
+ def test_mixed_precision(self):
253
+ keras.mixed_precision.set_global_policy("mixed_float16")
254
+
255
+ # try/finally block to ensure subsequent tests run in float32
256
+ try:
257
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
258
+ for model_class in self.all_model_classes[:2]:
259
+ class_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
260
+ model = model_class(config)
261
+ outputs = model(class_inputs_dict)
262
+
263
+ self.assertIsNotNone(outputs)
264
+ finally:
265
+ keras.mixed_precision.set_global_policy("float32")
266
+
267
+ @slow
268
+ def test_train_pipeline_custom_model(self):
269
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
270
+ # head_mask and decoder_head_mask has different shapes than other input args
271
+ if "head_mask" in inputs_dict:
272
+ del inputs_dict["head_mask"]
273
+ if "decoder_head_mask" in inputs_dict:
274
+ del inputs_dict["decoder_head_mask"]
275
+ if "cross_attn_head_mask" in inputs_dict:
276
+ del inputs_dict["cross_attn_head_mask"]
277
+ tf_main_layer_classes = {
278
+ module_member
279
+ for model_class in self.all_model_classes
280
+ for module in (import_module(model_class.__module__),)
281
+ for module_member_name in dir(module)
282
+ if module_member_name.endswith("MainLayer")
283
+ for module_member in (getattr(module, module_member_name),)
284
+ if isinstance(module_member, type)
285
+ and keras.layers.Layer in module_member.__bases__
286
+ and getattr(module_member, "_keras_serializable", False)
287
+ }
288
+
289
+ for main_layer_class in tf_main_layer_classes:
290
+ # T5MainLayer needs an embed_tokens parameter when called without the inputs_embeds parameter
291
+ if "T5" in main_layer_class.__name__:
292
+ # Take the same values than in TFT5ModelTester for this shared layer
293
+ shared = TFSharedEmbeddings(self.model_tester.vocab_size, self.model_tester.hidden_size, name="shared")
294
+ config.use_cache = False
295
+ main_layer = main_layer_class(config, embed_tokens=shared)
296
+ else:
297
+ main_layer = main_layer_class(config)
298
+
299
+ symbolic_inputs = {
300
+ name: keras.Input(tensor.shape[1:], dtype=tensor.dtype) for name, tensor in inputs_dict.items()
301
+ }
302
+
303
+ if hasattr(self.model_tester, "num_labels"):
304
+ num_labels = self.model_tester.num_labels
305
+ else:
306
+ num_labels = 2
307
+
308
+ X = tf.data.Dataset.from_tensor_slices(
309
+ (inputs_dict, np.ones((self.model_tester.batch_size, self.model_tester.seq_length, num_labels, 1)))
310
+ ).batch(1)
311
+
312
+ hidden_states = main_layer(symbolic_inputs)[0]
313
+ outputs = keras.layers.Dense(num_labels, activation="softmax", name="outputs")(hidden_states)
314
+ model = keras.models.Model(inputs=symbolic_inputs, outputs=[outputs])
315
+
316
+ model.compile(loss="binary_crossentropy", optimizer="adam", metrics=["binary_accuracy"])
317
+ model.fit(X, epochs=1)
318
+
319
+ with tempfile.TemporaryDirectory() as tmpdirname:
320
+ filepath = os.path.join(tmpdirname, "keras_model.h5")
321
+ model.save(filepath)
322
+ if "T5" in main_layer_class.__name__:
323
+ model = keras.models.load_model(
324
+ filepath,
325
+ custom_objects={
326
+ main_layer_class.__name__: main_layer_class,
327
+ "TFSharedEmbeddings": TFSharedEmbeddings,
328
+ },
329
+ )
330
+ else:
331
+ model = keras.models.load_model(
332
+ filepath, custom_objects={main_layer_class.__name__: main_layer_class}
333
+ )
334
+ assert isinstance(model, keras.Model)
335
+ model(inputs_dict)
336
+
337
+ @slow
338
+ def test_graph_mode_with_inputs_embeds(self):
339
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
340
+
341
+ for model_class in self.all_model_classes[:2]:
342
+ model = model_class(config)
343
+
344
+ inputs = copy.deepcopy(inputs_dict)
345
+
346
+ if not self.is_encoder_decoder:
347
+ input_ids = inputs["input_ids"]
348
+ del inputs["input_ids"]
349
+ else:
350
+ encoder_input_ids = inputs["input_ids"]
351
+ decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids)
352
+ del inputs["input_ids"]
353
+ inputs.pop("decoder_input_ids", None)
354
+
355
+ if not self.is_encoder_decoder:
356
+ inputs["inputs_embeds"] = model.get_input_embeddings()(input_ids)
357
+ else:
358
+ inputs["inputs_embeds"] = model.get_input_embeddings()(encoder_input_ids)
359
+ inputs["decoder_inputs_embeds"] = model.get_input_embeddings()(decoder_input_ids)
360
+
361
+ inputs = self._prepare_for_class(inputs, model_class)
362
+
363
+ @tf.function
364
+ def run_in_graph_mode():
365
+ return model(inputs)
366
+
367
+ outputs = run_in_graph_mode()
368
+ self.assertIsNotNone(outputs)
369
+
370
+ def _generate_random_bad_tokens(self, num_bad_tokens, model):
371
+ # special tokens cannot be bad tokens
372
+ special_tokens = []
373
+ if model.config.bos_token_id is not None:
374
+ special_tokens.append(model.config.bos_token_id)
375
+ if model.config.pad_token_id is not None:
376
+ special_tokens.append(model.config.pad_token_id)
377
+ if model.config.eos_token_id is not None:
378
+ special_tokens.append(model.config.eos_token_id)
379
+
380
+ # create random bad tokens that are not special tokens
381
+ bad_tokens = []
382
+ while len(bad_tokens) < num_bad_tokens:
383
+ token = tf.squeeze(ids_tensor((1, 1), self.model_tester.vocab_size), 0).numpy()[0]
384
+ if token not in special_tokens:
385
+ bad_tokens.append(token)
386
+ return bad_tokens
387
+
388
+ def _check_generated_ids(self, output_ids):
389
+ for token_id in output_ids[0].numpy().tolist():
390
+ self.assertGreaterEqual(token_id, 0)
391
+ self.assertLess(token_id, self.model_tester.vocab_size)
392
+
393
+ def _check_match_tokens(self, generated_ids, bad_words_ids):
394
+ # for all bad word tokens
395
+ for bad_word_ids in bad_words_ids:
396
+ # for all slices in batch
397
+ for generated_ids_slice in generated_ids:
398
+ # for all word idx
399
+ for i in range(len(bad_word_ids), len(generated_ids_slice)):
400
+ # if tokens match
401
+ if generated_ids_slice[i - len(bad_word_ids) : i] == bad_word_ids:
402
+ return True
403
+ return False
docs/transformers/tests/utils/test_modeling_tf_utils.py ADDED
@@ -0,0 +1,662 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2019 HuggingFace Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from __future__ import annotations
17
+
18
+ import json
19
+ import os
20
+ import random
21
+ import tempfile
22
+ import unittest
23
+ import unittest.mock as mock
24
+
25
+ from huggingface_hub import HfFolder, snapshot_download
26
+ from requests.exceptions import HTTPError
27
+
28
+ from transformers import is_tf_available
29
+ from transformers.configuration_utils import PretrainedConfig
30
+ from transformers.testing_utils import ( # noqa: F401
31
+ TOKEN,
32
+ USER,
33
+ CaptureLogger,
34
+ TemporaryHubRepo,
35
+ is_staging_test,
36
+ require_safetensors,
37
+ require_tf,
38
+ slow,
39
+ )
40
+ from transformers.utils import (
41
+ SAFE_WEIGHTS_INDEX_NAME,
42
+ SAFE_WEIGHTS_NAME,
43
+ TF2_WEIGHTS_INDEX_NAME,
44
+ TF2_WEIGHTS_NAME,
45
+ logging,
46
+ )
47
+
48
+
49
+ logger = logging.get_logger(__name__)
50
+
51
+
52
+ if is_tf_available():
53
+ import h5py
54
+ import numpy as np
55
+ import tensorflow as tf
56
+
57
+ from transformers import (
58
+ BertConfig,
59
+ RagRetriever,
60
+ TFBertForSequenceClassification,
61
+ TFBertModel,
62
+ TFRagModel,
63
+ )
64
+ from transformers.modeling_tf_utils import keras, tf_shard_checkpoint, unpack_inputs
65
+ from transformers.tf_utils import stable_softmax
66
+
67
+ tf.config.experimental.enable_tensor_float_32_execution(False)
68
+
69
+
70
+ @require_tf
71
+ class TFModelUtilsTest(unittest.TestCase):
72
+ def test_cached_files_are_used_when_internet_is_down(self):
73
+ # A mock response for an HTTP head request to emulate server down
74
+ response_mock = mock.Mock()
75
+ response_mock.status_code = 500
76
+ response_mock.headers = {}
77
+ response_mock.raise_for_status.side_effect = HTTPError
78
+ response_mock.json.return_value = {}
79
+
80
+ # Download this model to make sure it's in the cache.
81
+ _ = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
82
+
83
+ # Under the mock environment we get a 500 error when trying to reach the model.
84
+ with mock.patch("requests.Session.request", return_value=response_mock) as mock_head:
85
+ _ = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
86
+ # This check we did call the fake head request
87
+ mock_head.assert_called()
88
+
89
+ # tests whether the unpack_inputs function behaves as expected
90
+ def test_unpack_inputs(self):
91
+ class DummyModel:
92
+ def __init__(self):
93
+ config_kwargs = {"output_attentions": False, "output_hidden_states": False, "return_dict": False}
94
+ self.config = PretrainedConfig(**config_kwargs)
95
+ self.main_input_name = "input_ids"
96
+
97
+ @unpack_inputs
98
+ def call(
99
+ self,
100
+ input_ids=None,
101
+ past_key_values=None,
102
+ output_attentions=None,
103
+ output_hidden_states=None,
104
+ return_dict=None,
105
+ ):
106
+ return input_ids, past_key_values, output_attentions, output_hidden_states, return_dict
107
+
108
+ @unpack_inputs
109
+ def foo(self, pixel_values, output_attentions=None, output_hidden_states=None, return_dict=None):
110
+ return pixel_values, output_attentions, output_hidden_states, return_dict
111
+
112
+ dummy_model = DummyModel()
113
+ input_ids = tf.constant([0, 1, 2, 3], dtype=tf.int32)
114
+ past_key_values = tf.constant([4, 5, 6, 7], dtype=tf.int32)
115
+ pixel_values = tf.constant([8, 9, 10, 11], dtype=tf.int32)
116
+
117
+ # test case 1: Pass inputs as keyword arguments; Booleans are inherited from the config.
118
+ output = dummy_model.call(input_ids=input_ids, past_key_values=past_key_values)
119
+ tf.debugging.assert_equal(output[0], input_ids)
120
+ tf.debugging.assert_equal(output[1], past_key_values)
121
+ self.assertFalse(output[2])
122
+ self.assertFalse(output[3])
123
+ self.assertFalse(output[4])
124
+
125
+ # test case 2: Same as above, but with positional arguments.
126
+ output = dummy_model.call(input_ids, past_key_values)
127
+ tf.debugging.assert_equal(output[0], input_ids)
128
+ tf.debugging.assert_equal(output[1], past_key_values)
129
+ self.assertFalse(output[2])
130
+ self.assertFalse(output[3])
131
+ self.assertFalse(output[4])
132
+
133
+ # test case 3: We can also pack everything in the first input.
134
+ output = dummy_model.call(input_ids={"input_ids": input_ids, "past_key_values": past_key_values})
135
+ tf.debugging.assert_equal(output[0], input_ids)
136
+ tf.debugging.assert_equal(output[1], past_key_values)
137
+ self.assertFalse(output[2])
138
+ self.assertFalse(output[3])
139
+ self.assertFalse(output[4])
140
+
141
+ # test case 4: Explicit boolean arguments should override the config.
142
+ output = dummy_model.call(
143
+ input_ids=input_ids, past_key_values=past_key_values, output_attentions=False, return_dict=True
144
+ )
145
+ tf.debugging.assert_equal(output[0], input_ids)
146
+ tf.debugging.assert_equal(output[1], past_key_values)
147
+ self.assertFalse(output[2])
148
+ self.assertFalse(output[3])
149
+ self.assertTrue(output[4])
150
+
151
+ # test case 5: Unexpected arguments should raise an exception.
152
+ with self.assertRaises(ValueError):
153
+ output = dummy_model.call(input_ids=input_ids, past_key_values=past_key_values, foo="bar")
154
+
155
+ # test case 6: the decorator is independent from `main_input_name` -- it treats the first argument of the
156
+ # decorated function as its main input.
157
+ output = dummy_model.foo(pixel_values=pixel_values)
158
+ tf.debugging.assert_equal(output[0], pixel_values)
159
+ self.assertFalse(output[1])
160
+ self.assertFalse(output[2])
161
+ self.assertFalse(output[3])
162
+
163
+ # Tests whether the stable softmax is stable on CPU, with and without XLA
164
+ def test_xla_stable_softmax(self):
165
+ large_penalty = -1e9
166
+ n_tokens = 10
167
+ batch_size = 8
168
+
169
+ def masked_softmax(x, boolean_mask):
170
+ numerical_mask = (1.0 - tf.cast(boolean_mask, dtype=tf.float32)) * large_penalty
171
+ masked_x = x + numerical_mask
172
+ return stable_softmax(masked_x)
173
+
174
+ xla_masked_softmax = tf.function(masked_softmax, jit_compile=True)
175
+ xla_stable_softmax = tf.function(stable_softmax, jit_compile=True)
176
+ x = tf.random.normal((batch_size, n_tokens))
177
+
178
+ # Same outcome regardless of the boolean mask here
179
+ masked_tokens = random.randint(0, n_tokens)
180
+ boolean_mask = tf.convert_to_tensor([[1] * (n_tokens - masked_tokens) + [0] * masked_tokens], dtype=tf.int32)
181
+
182
+ # We can randomly mask a random numerical input OUTSIDE XLA
183
+ numerical_mask = (1.0 - tf.cast(boolean_mask, dtype=tf.float32)) * large_penalty
184
+ masked_x = x + numerical_mask
185
+ xla_out = xla_stable_softmax(masked_x)
186
+ out = stable_softmax(masked_x)
187
+ assert tf.experimental.numpy.allclose(xla_out, out)
188
+
189
+ # The stable softmax has the same output as the original softmax
190
+ unstable_out = tf.nn.softmax(masked_x)
191
+ assert tf.experimental.numpy.allclose(unstable_out, out)
192
+
193
+ # We can randomly mask a random numerical input INSIDE XLA
194
+ xla_out = xla_masked_softmax(x, boolean_mask)
195
+ out = masked_softmax(x, boolean_mask)
196
+ assert tf.experimental.numpy.allclose(xla_out, out)
197
+
198
+ def test_checkpoint_sharding_from_hub(self):
199
+ model = TFBertModel.from_pretrained("ArthurZ/tiny-random-bert-sharded")
200
+ # the model above is the same as the model below, just a sharded version.
201
+ ref_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
202
+ for p1, p2 in zip(model.weights, ref_model.weights):
203
+ assert np.allclose(p1.numpy(), p2.numpy())
204
+
205
+ def test_sharded_checkpoint_with_prefix(self):
206
+ model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert", load_weight_prefix="a/b")
207
+ sharded_model = TFBertModel.from_pretrained("ArthurZ/tiny-random-bert-sharded", load_weight_prefix="a/b")
208
+ for p1, p2 in zip(model.weights, sharded_model.weights):
209
+ self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
210
+ self.assertTrue(p1.name.startswith("a/b/"))
211
+ self.assertTrue(p2.name.startswith("a/b/"))
212
+
213
+ def test_sharded_checkpoint_transfer(self):
214
+ # If this doesn't throw an error then the test passes
215
+ TFBertForSequenceClassification.from_pretrained("ArthurZ/tiny-random-bert-sharded")
216
+
217
+ def test_shard_checkpoint(self):
218
+ # This is the model we will use, total size 340,000 bytes.
219
+ model = keras.Sequential(
220
+ [
221
+ keras.layers.Dense(200, use_bias=False), # size 80,000
222
+ keras.layers.Dense(200, use_bias=False), # size 160,000
223
+ keras.layers.Dense(100, use_bias=False), # size 80,000
224
+ keras.layers.Dense(50, use_bias=False), # size 20,000
225
+ ]
226
+ )
227
+ inputs = tf.zeros((1, 100), dtype=tf.float32)
228
+ model(inputs)
229
+ weights = model.weights
230
+ weights_dict = {w.name: w for w in weights}
231
+ with self.subTest("No shard when max size is bigger than model size"):
232
+ shards, index = tf_shard_checkpoint(weights)
233
+ self.assertIsNone(index)
234
+ self.assertDictEqual(shards, {TF2_WEIGHTS_NAME: weights})
235
+
236
+ with self.subTest("Test sharding, no weights bigger than max size"):
237
+ shards, index = tf_shard_checkpoint(weights, max_shard_size="300kB")
238
+ # Split is first two layers then last two.
239
+ self.assertDictEqual(
240
+ index,
241
+ {
242
+ "metadata": {"total_size": 340000},
243
+ "weight_map": {
244
+ "dense/kernel:0": "tf_model-00001-of-00002.h5",
245
+ "dense_1/kernel:0": "tf_model-00001-of-00002.h5",
246
+ "dense_2/kernel:0": "tf_model-00002-of-00002.h5",
247
+ "dense_3/kernel:0": "tf_model-00002-of-00002.h5",
248
+ },
249
+ },
250
+ )
251
+
252
+ shard1 = [weights_dict["dense/kernel:0"], weights_dict["dense_1/kernel:0"]]
253
+ shard2 = [weights_dict["dense_2/kernel:0"], weights_dict["dense_3/kernel:0"]]
254
+ self.assertDictEqual(shards, {"tf_model-00001-of-00002.h5": shard1, "tf_model-00002-of-00002.h5": shard2})
255
+
256
+ with self.subTest("Test sharding with weights bigger than max size"):
257
+ shards, index = tf_shard_checkpoint(weights, max_shard_size="100kB")
258
+ # Split is first layer, second layer then last 2.
259
+ self.assertDictEqual(
260
+ index,
261
+ {
262
+ "metadata": {"total_size": 340000},
263
+ "weight_map": {
264
+ "dense/kernel:0": "tf_model-00001-of-00003.h5",
265
+ "dense_1/kernel:0": "tf_model-00002-of-00003.h5",
266
+ "dense_2/kernel:0": "tf_model-00003-of-00003.h5",
267
+ "dense_3/kernel:0": "tf_model-00003-of-00003.h5",
268
+ },
269
+ },
270
+ )
271
+
272
+ shard1 = [weights_dict["dense/kernel:0"]]
273
+ shard2 = [weights_dict["dense_1/kernel:0"]]
274
+ shard3 = [weights_dict["dense_2/kernel:0"], weights_dict["dense_3/kernel:0"]]
275
+ self.assertDictEqual(
276
+ shards,
277
+ {
278
+ "tf_model-00001-of-00003.h5": shard1,
279
+ "tf_model-00002-of-00003.h5": shard2,
280
+ "tf_model-00003-of-00003.h5": shard3,
281
+ },
282
+ )
283
+
284
+ @slow
285
+ def test_special_layer_name_sharding(self):
286
+ retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True)
287
+ model = TFRagModel.from_pretrained("facebook/rag-token-nq", retriever=retriever)
288
+
289
+ with tempfile.TemporaryDirectory() as tmp_dir:
290
+ for max_size in ["150kB", "150kiB", "200kB", "200kiB"]:
291
+ model.save_pretrained(tmp_dir, max_shard_size=max_size)
292
+ ref_model = TFRagModel.from_pretrained(tmp_dir, retriever=retriever)
293
+ for p1, p2 in zip(model.weights, ref_model.weights):
294
+ assert np.allclose(p1.numpy(), p2.numpy())
295
+
296
+ @require_safetensors
297
+ def test_checkpoint_sharding_local(self):
298
+ model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
299
+
300
+ with tempfile.TemporaryDirectory() as tmp_dir:
301
+ # We use the same folder for various sizes to make sure a new save erases the old checkpoint.
302
+ for max_size in ["150kB", "150kiB", "200kB", "200kiB"]:
303
+ model.save_pretrained(tmp_dir, max_shard_size=max_size)
304
+
305
+ # Get each shard file and its size
306
+ shard_to_size = {}
307
+ for shard in os.listdir(tmp_dir):
308
+ if shard.endswith(".h5"):
309
+ shard_file = os.path.join(tmp_dir, shard)
310
+ shard_to_size[shard_file] = os.path.getsize(shard_file)
311
+
312
+ index_file = os.path.join(tmp_dir, TF2_WEIGHTS_INDEX_NAME)
313
+ # Check there is an index but no regular weight file
314
+ self.assertTrue(os.path.isfile(index_file))
315
+ self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_NAME)))
316
+
317
+ # Check a file is bigger than max_size only when it has a single weight
318
+ for shard_file, size in shard_to_size.items():
319
+ if max_size.endswith("kiB"):
320
+ max_size_int = int(max_size[:-3]) * 2**10
321
+ else:
322
+ max_size_int = int(max_size[:-2]) * 10**3
323
+ # Note: pickle adds some junk so the weight of the file can end up being slightly bigger than
324
+ # the size asked for (since we count parameters)
325
+ if size >= max_size_int + 50000:
326
+ with h5py.File(shard_file, "r") as state_file:
327
+ self.assertEqual(len(state_file), 1)
328
+
329
+ # Check the index and the shard files found match
330
+ with open(index_file, encoding="utf-8") as f:
331
+ index = json.loads(f.read())
332
+
333
+ all_shards = set(index["weight_map"].values())
334
+ shards_found = {f for f in os.listdir(tmp_dir) if f.endswith(".h5")}
335
+ self.assertSetEqual(all_shards, shards_found)
336
+
337
+ # Finally, check the model can be reloaded
338
+ new_model = TFBertModel.from_pretrained(tmp_dir)
339
+
340
+ model.build_in_name_scope()
341
+ new_model.build_in_name_scope()
342
+
343
+ for p1, p2 in zip(model.weights, new_model.weights):
344
+ self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
345
+
346
+ def test_safetensors_checkpoint_sharding_local(self):
347
+ model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
348
+
349
+ with tempfile.TemporaryDirectory() as tmp_dir:
350
+ # We use the same folder for various sizes to make sure a new save erases the old checkpoint.
351
+ for max_size in ["150kB", "150kiB", "200kB", "200kiB"]:
352
+ model.save_pretrained(tmp_dir, max_shard_size=max_size, safe_serialization=True)
353
+
354
+ # Get each shard file and its size
355
+ shard_to_size = {}
356
+ for shard in os.listdir(tmp_dir):
357
+ if shard.endswith(".h5"):
358
+ shard_file = os.path.join(tmp_dir, shard)
359
+ shard_to_size[shard_file] = os.path.getsize(shard_file)
360
+
361
+ index_file = os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)
362
+ # Check there is an index but no regular weight file
363
+ self.assertTrue(os.path.isfile(index_file))
364
+ self.assertFalse(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
365
+ self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_NAME)))
366
+ self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_INDEX_NAME)))
367
+
368
+ # Check the index and the shard files found match
369
+ with open(index_file, encoding="utf-8") as f:
370
+ index = json.loads(f.read())
371
+
372
+ all_shards = set(index["weight_map"].values())
373
+ shards_found = {f for f in os.listdir(tmp_dir) if f.endswith(".safetensors")}
374
+ self.assertSetEqual(all_shards, shards_found)
375
+
376
+ # Finally, check the model can be reloaded
377
+ new_model = TFBertModel.from_pretrained(tmp_dir)
378
+
379
+ model.build_in_name_scope()
380
+ new_model.build_in_name_scope()
381
+
382
+ for p1, p2 in zip(model.weights, new_model.weights):
383
+ self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
384
+
385
+ @slow
386
+ def test_save_pretrained_signatures(self):
387
+ model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
388
+
389
+ # Short custom TF signature function.
390
+ # `input_signature` is specific to BERT.
391
+ @tf.function(
392
+ input_signature=[
393
+ [
394
+ tf.TensorSpec([None, None], tf.int32, name="input_ids"),
395
+ tf.TensorSpec([None, None], tf.int32, name="token_type_ids"),
396
+ tf.TensorSpec([None, None], tf.int32, name="attention_mask"),
397
+ ]
398
+ ]
399
+ )
400
+ def serving_fn(input):
401
+ return model(input)
402
+
403
+ # Using default signature (default behavior) overrides 'serving_default'
404
+ with tempfile.TemporaryDirectory() as tmp_dir:
405
+ model.save_pretrained(tmp_dir, saved_model=True, signatures=None)
406
+ model_loaded = keras.models.load_model(f"{tmp_dir}/saved_model/1")
407
+ self.assertTrue("serving_default" in list(model_loaded.signatures.keys()))
408
+
409
+ # Providing custom signature function
410
+ with tempfile.TemporaryDirectory() as tmp_dir:
411
+ model.save_pretrained(tmp_dir, saved_model=True, signatures={"custom_signature": serving_fn})
412
+ model_loaded = keras.models.load_model(f"{tmp_dir}/saved_model/1")
413
+ self.assertTrue("custom_signature" in list(model_loaded.signatures.keys()))
414
+
415
+ # Providing multiple custom signature function
416
+ with tempfile.TemporaryDirectory() as tmp_dir:
417
+ model.save_pretrained(
418
+ tmp_dir,
419
+ saved_model=True,
420
+ signatures={"custom_signature_1": serving_fn, "custom_signature_2": serving_fn},
421
+ )
422
+ model_loaded = keras.models.load_model(f"{tmp_dir}/saved_model/1")
423
+ self.assertTrue("custom_signature_1" in list(model_loaded.signatures.keys()))
424
+ self.assertTrue("custom_signature_2" in list(model_loaded.signatures.keys()))
425
+
426
+ @require_safetensors
427
+ def test_safetensors_save_and_load(self):
428
+ model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
429
+ with tempfile.TemporaryDirectory() as tmp_dir:
430
+ model.save_pretrained(tmp_dir, safe_serialization=True)
431
+ # No tf_model.h5 file, only a model.safetensors
432
+ self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
433
+ self.assertFalse(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)))
434
+ self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_NAME)))
435
+ self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_INDEX_NAME)))
436
+
437
+ new_model = TFBertModel.from_pretrained(tmp_dir)
438
+
439
+ # Check models are equal
440
+ for p1, p2 in zip(model.weights, new_model.weights):
441
+ self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
442
+
443
+ @require_safetensors
444
+ def test_safetensors_sharded_save_and_load(self):
445
+ model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
446
+ with tempfile.TemporaryDirectory() as tmp_dir:
447
+ model.save_pretrained(tmp_dir, safe_serialization=True, max_shard_size="150kB")
448
+ # No tf weights or index file, only a safetensors index
449
+ self.assertFalse(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
450
+ self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_NAME)))
451
+ self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)))
452
+ self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_INDEX_NAME)))
453
+
454
+ new_model = TFBertModel.from_pretrained(tmp_dir)
455
+
456
+ # Check models are equal
457
+ for p1, p2 in zip(model.weights, new_model.weights):
458
+ self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
459
+
460
+ @require_safetensors
461
+ def test_safetensors_load_from_hub(self):
462
+ tf_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
463
+
464
+ # Can load from the TF-formatted checkpoint
465
+ safetensors_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-safetensors-tf")
466
+
467
+ # Check models are equal
468
+ for p1, p2 in zip(safetensors_model.weights, tf_model.weights):
469
+ self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
470
+
471
+ # Can load from the PyTorch-formatted checkpoint
472
+ safetensors_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-safetensors")
473
+
474
+ # Check models are equal
475
+ for p1, p2 in zip(safetensors_model.weights, tf_model.weights):
476
+ self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
477
+
478
+ @require_safetensors
479
+ def test_safetensors_tf_from_tf(self):
480
+ model = TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-tf-only")
481
+
482
+ with tempfile.TemporaryDirectory() as tmp_dir:
483
+ model.save_pretrained(tmp_dir, safe_serialization=True)
484
+ new_model = TFBertModel.from_pretrained(tmp_dir)
485
+
486
+ for p1, p2 in zip(model.weights, new_model.weights):
487
+ self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
488
+
489
+ @require_safetensors
490
+ def test_safetensors_tf_from_sharded_h5_with_sharded_safetensors_local(self):
491
+ with tempfile.TemporaryDirectory() as tmp_dir:
492
+ path = snapshot_download("hf-internal-testing/tiny-bert-tf-safetensors-h5-sharded", cache_dir=tmp_dir)
493
+
494
+ # This should not raise even if there are two types of sharded weights
495
+ TFBertModel.from_pretrained(path)
496
+
497
+ @require_safetensors
498
+ def test_safetensors_tf_from_sharded_h5_with_sharded_safetensors_hub(self):
499
+ # Confirm that we can correctly load the safetensors weights from a sharded hub repo even when TF weights present
500
+ TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-tf-safetensors-h5-sharded", use_safetensors=True)
501
+ # Confirm that we can access the TF weights too
502
+ TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-tf-safetensors-h5-sharded", use_safetensors=False)
503
+
504
+ @require_safetensors
505
+ def test_safetensors_load_from_local(self):
506
+ """
507
+ This test checks that we can load safetensors from a checkpoint that only has those on the Hub
508
+ """
509
+ with tempfile.TemporaryDirectory() as tmp:
510
+ location = snapshot_download("hf-internal-testing/tiny-bert-tf-only", cache_dir=tmp)
511
+ tf_model = TFBertModel.from_pretrained(location)
512
+
513
+ with tempfile.TemporaryDirectory() as tmp:
514
+ location = snapshot_download("hf-internal-testing/tiny-bert-tf-safetensors-only", cache_dir=tmp)
515
+ safetensors_model = TFBertModel.from_pretrained(location)
516
+
517
+ for p1, p2 in zip(tf_model.weights, safetensors_model.weights):
518
+ self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
519
+
520
+ @require_safetensors
521
+ def test_safetensors_load_from_hub_from_safetensors_pt(self):
522
+ """
523
+ This test checks that we can load safetensors from a checkpoint that only has those on the Hub.
524
+ saved in the "pt" format.
525
+ """
526
+ tf_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-h5")
527
+
528
+ # Can load from the PyTorch-formatted checkpoint
529
+ safetensors_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors")
530
+ for p1, p2 in zip(tf_model.weights, safetensors_model.weights):
531
+ self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
532
+
533
+ @require_safetensors
534
+ def test_safetensors_load_from_local_from_safetensors_pt(self):
535
+ """
536
+ This test checks that we can load safetensors from a local checkpoint that only has those
537
+ saved in the "pt" format.
538
+ """
539
+ with tempfile.TemporaryDirectory() as tmp:
540
+ location = snapshot_download("hf-internal-testing/tiny-bert-h5", cache_dir=tmp)
541
+ tf_model = TFBertModel.from_pretrained(location)
542
+
543
+ # Can load from the PyTorch-formatted checkpoint
544
+ with tempfile.TemporaryDirectory() as tmp:
545
+ location = snapshot_download("hf-internal-testing/tiny-bert-pt-safetensors", cache_dir=tmp)
546
+ safetensors_model = TFBertModel.from_pretrained(location)
547
+
548
+ for p1, p2 in zip(tf_model.weights, safetensors_model.weights):
549
+ self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
550
+
551
+ @require_safetensors
552
+ def test_safetensors_load_from_hub_h5_before_safetensors(self):
553
+ """
554
+ This test checks that we'll first download h5 weights before safetensors
555
+ The safetensors file on that repo is a pt safetensors and therefore cannot be loaded without PyTorch
556
+ """
557
+ TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors-msgpack")
558
+
559
+ @require_safetensors
560
+ def test_safetensors_load_from_local_h5_before_safetensors(self):
561
+ """
562
+ This test checks that we'll first download h5 weights before safetensors
563
+ The safetensors file on that repo is a pt safetensors and therefore cannot be loaded without PyTorch
564
+ """
565
+ with tempfile.TemporaryDirectory() as tmp:
566
+ location = snapshot_download("hf-internal-testing/tiny-bert-pt-safetensors-msgpack", cache_dir=tmp)
567
+ TFBertModel.from_pretrained(location)
568
+
569
+
570
+ @require_tf
571
+ @is_staging_test
572
+ class TFModelPushToHubTester(unittest.TestCase):
573
+ @classmethod
574
+ def setUpClass(cls):
575
+ cls._token = TOKEN
576
+ HfFolder.save_token(TOKEN)
577
+
578
+ def test_push_to_hub(self):
579
+ with TemporaryHubRepo(token=self._token) as tmp_repo:
580
+ config = BertConfig(
581
+ vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
582
+ )
583
+ model = TFBertModel(config)
584
+ # Make sure model is properly initialized
585
+ model.build_in_name_scope()
586
+
587
+ logging.set_verbosity_info()
588
+ logger = logging.get_logger("transformers.utils.hub")
589
+ with CaptureLogger(logger) as cl:
590
+ model.push_to_hub(tmp_repo.repo_id, token=self._token)
591
+ logging.set_verbosity_warning()
592
+ # Check the model card was created and uploaded.
593
+ self.assertIn("Uploading the following files to __DUMMY_TRANSFORMERS_USER__/test-model-tf", cl.out)
594
+
595
+ new_model = TFBertModel.from_pretrained(tmp_repo.repo_id)
596
+ models_equal = True
597
+ for p1, p2 in zip(model.weights, new_model.weights):
598
+ if not tf.math.reduce_all(p1 == p2):
599
+ models_equal = False
600
+ break
601
+ self.assertTrue(models_equal)
602
+
603
+ def test_push_to_hub_via_save_pretrained(self):
604
+ with TemporaryHubRepo(token=self._token) as tmp_repo:
605
+ config = BertConfig(
606
+ vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
607
+ )
608
+ model = TFBertModel(config)
609
+ # Make sure model is properly initialized
610
+ model.build_in_name_scope()
611
+
612
+ # Push to hub via save_pretrained
613
+ with tempfile.TemporaryDirectory() as tmp_dir:
614
+ model.save_pretrained(tmp_dir, repo_id=tmp_repo.repo_id, push_to_hub=True, token=self._token)
615
+
616
+ new_model = TFBertModel.from_pretrained(tmp_repo.repo_id)
617
+ models_equal = True
618
+ for p1, p2 in zip(model.weights, new_model.weights):
619
+ if not tf.math.reduce_all(p1 == p2):
620
+ models_equal = False
621
+ break
622
+ self.assertTrue(models_equal)
623
+
624
+ def test_push_to_hub_in_organization(self):
625
+ with TemporaryHubRepo(namespace="valid_org", token=self._token) as tmp_repo:
626
+ config = BertConfig(
627
+ vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
628
+ )
629
+ model = TFBertModel(config)
630
+ # Make sure model is properly initialized
631
+ model.build_in_name_scope()
632
+
633
+ model.push_to_hub(tmp_repo.repo_id, token=self._token)
634
+
635
+ new_model = TFBertModel.from_pretrained(tmp_repo.repo_id)
636
+ models_equal = True
637
+ for p1, p2 in zip(model.weights, new_model.weights):
638
+ if not tf.math.reduce_all(p1 == p2):
639
+ models_equal = False
640
+ break
641
+ self.assertTrue(models_equal)
642
+
643
+ def test_push_to_hub_in_organization_via_save_pretrained(self):
644
+ with TemporaryHubRepo(namespace="valid_org", token=self._token) as tmp_repo:
645
+ config = BertConfig(
646
+ vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
647
+ )
648
+ model = TFBertModel(config)
649
+ # Make sure model is properly initialized
650
+ model.build_in_name_scope()
651
+
652
+ # Push to hub via save_pretrained
653
+ with tempfile.TemporaryDirectory() as tmp_dir:
654
+ model.save_pretrained(tmp_dir, push_to_hub=True, token=self._token, repo_id=tmp_repo.repo_id)
655
+
656
+ new_model = TFBertModel.from_pretrained(tmp_repo.repo_id)
657
+ models_equal = True
658
+ for p1, p2 in zip(model.weights, new_model.weights):
659
+ if not tf.math.reduce_all(p1 == p2):
660
+ models_equal = False
661
+ break
662
+ self.assertTrue(models_equal)
docs/transformers/tests/utils/test_modeling_utils.py ADDED
The diff for this file is too large to render. See raw diff
 
docs/transformers/tests/utils/test_offline.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import subprocess
16
+ import sys
17
+ import unittest
18
+
19
+ from transformers import BertConfig, BertModel, BertTokenizer, pipeline
20
+ from transformers.testing_utils import TestCasePlus, require_torch
21
+
22
+
23
+ class OfflineTests(TestCasePlus):
24
+ @require_torch
25
+ @unittest.skip("This test is failing on main") # TODO matt/ydshieh, this test needs to be fixed
26
+ def test_offline_mode(self):
27
+ # this test is a bit tricky since TRANSFORMERS_OFFLINE can only be changed before
28
+ # `transformers` is loaded, and it's too late for inside pytest - so we are changing it
29
+ # while running an external program
30
+
31
+ # python one-liner segments
32
+
33
+ # this must be loaded before socket.socket is monkey-patched
34
+ load = """
35
+ from transformers import BertConfig, BertModel, BertTokenizer, pipeline
36
+ """
37
+
38
+ run = """
39
+ mname = "hf-internal-testing/tiny-random-bert"
40
+ BertConfig.from_pretrained(mname)
41
+ BertModel.from_pretrained(mname)
42
+ BertTokenizer.from_pretrained(mname)
43
+ pipe = pipeline(task="fill-mask", model=mname)
44
+ print("success")
45
+ """
46
+
47
+ mock = """
48
+ import socket
49
+ def offline_socket(*args, **kwargs): raise RuntimeError("Offline mode is enabled, we shouldn't access internet")
50
+ socket.socket = offline_socket
51
+ """
52
+
53
+ # Force fetching the files so that we can use the cache
54
+ mname = "hf-internal-testing/tiny-random-bert"
55
+ BertConfig.from_pretrained(mname)
56
+ BertModel.from_pretrained(mname)
57
+ BertTokenizer.from_pretrained(mname)
58
+ pipeline(task="fill-mask", model=mname)
59
+
60
+ # baseline - just load from_pretrained with normal network
61
+ # should succeed as TRANSFORMERS_OFFLINE=1 tells it to use local files
62
+ stdout, _ = self._execute_with_env(load, run, mock, TRANSFORMERS_OFFLINE="1")
63
+ self.assertIn("success", stdout)
64
+
65
+ @require_torch
66
+ def test_offline_mode_no_internet(self):
67
+ # python one-liner segments
68
+ # this must be loaded before socket.socket is monkey-patched
69
+ load = """
70
+ from transformers import BertConfig, BertModel, BertTokenizer, pipeline
71
+ """
72
+
73
+ run = """
74
+ mname = "hf-internal-testing/tiny-random-bert"
75
+ BertConfig.from_pretrained(mname)
76
+ BertModel.from_pretrained(mname)
77
+ BertTokenizer.from_pretrained(mname)
78
+ pipe = pipeline(task="fill-mask", model=mname)
79
+ print("success")
80
+ """
81
+
82
+ mock = """
83
+ import socket
84
+ def offline_socket(*args, **kwargs): raise socket.error("Faking flaky internet")
85
+ socket.socket = offline_socket
86
+ """
87
+
88
+ # Force fetching the files so that we can use the cache
89
+ mname = "hf-internal-testing/tiny-random-bert"
90
+ BertConfig.from_pretrained(mname)
91
+ BertModel.from_pretrained(mname)
92
+ BertTokenizer.from_pretrained(mname)
93
+ pipeline(task="fill-mask", model=mname)
94
+
95
+ # baseline - just load from_pretrained with normal network
96
+ # should succeed
97
+ stdout, _ = self._execute_with_env(load, run, mock)
98
+ self.assertIn("success", stdout)
99
+
100
+ @require_torch
101
+ def test_offline_mode_sharded_checkpoint(self):
102
+ # this test is a bit tricky since TRANSFORMERS_OFFLINE can only be changed before
103
+ # `transformers` is loaded, and it's too late for inside pytest - so we are changing it
104
+ # while running an external program
105
+
106
+ # python one-liner segments
107
+
108
+ # this must be loaded before socket.socket is monkey-patched
109
+ load = """
110
+ from transformers import BertConfig, BertModel, BertTokenizer
111
+ """
112
+
113
+ run = """
114
+ mname = "hf-internal-testing/tiny-random-bert-sharded"
115
+ BertConfig.from_pretrained(mname)
116
+ BertModel.from_pretrained(mname)
117
+ print("success")
118
+ """
119
+
120
+ mock = """
121
+ import socket
122
+ def offline_socket(*args, **kwargs): raise ValueError("Offline mode is enabled")
123
+ socket.socket = offline_socket
124
+ """
125
+
126
+ # baseline - just load from_pretrained with normal network
127
+ # should succeed
128
+ stdout, _ = self._execute_with_env(load, run)
129
+ self.assertIn("success", stdout)
130
+
131
+ # next emulate no network
132
+ # Doesn't fail anymore since the model is in the cache due to other tests, so commenting this.
133
+ # self._execute_with_env(load, mock, run, should_fail=True, TRANSFORMERS_OFFLINE="0")
134
+
135
+ # should succeed as TRANSFORMERS_OFFLINE=1 tells it to use local files
136
+ stdout, _ = self._execute_with_env(load, mock, run, TRANSFORMERS_OFFLINE="1")
137
+ self.assertIn("success", stdout)
138
+
139
+ @require_torch
140
+ def test_offline_mode_pipeline_exception(self):
141
+ load = """
142
+ from transformers import pipeline
143
+ """
144
+ run = """
145
+ mname = "hf-internal-testing/tiny-random-bert"
146
+ pipe = pipeline(model=mname)
147
+ """
148
+
149
+ mock = """
150
+ import socket
151
+ def offline_socket(*args, **kwargs): raise socket.error("Offline mode is enabled")
152
+ socket.socket = offline_socket
153
+ """
154
+
155
+ _, stderr = self._execute_with_env(load, mock, run, should_fail=True, TRANSFORMERS_OFFLINE="1")
156
+ self.assertIn(
157
+ "You cannot infer task automatically within `pipeline` when using offline mode",
158
+ stderr.replace("\n", ""),
159
+ )
160
+
161
+ @require_torch
162
+ def test_offline_model_dynamic_model(self):
163
+ load = """
164
+ from transformers import AutoModel
165
+ """
166
+ run = """
167
+ mname = "hf-internal-testing/test_dynamic_model"
168
+ AutoModel.from_pretrained(mname, trust_remote_code=True)
169
+ print("success")
170
+ """
171
+
172
+ # baseline - just load from_pretrained with normal network
173
+ # should succeed
174
+ stdout, _ = self._execute_with_env(load, run)
175
+ self.assertIn("success", stdout)
176
+
177
+ # should succeed as TRANSFORMERS_OFFLINE=1 tells it to use local files
178
+ stdout, _ = self._execute_with_env(load, run, TRANSFORMERS_OFFLINE="1")
179
+ self.assertIn("success", stdout)
180
+
181
+ def test_is_offline_mode(self):
182
+ """
183
+ Test `_is_offline_mode` helper (should respect both HF_HUB_OFFLINE and legacy TRANSFORMERS_OFFLINE env vars)
184
+ """
185
+ load = "from transformers.utils import is_offline_mode"
186
+ run = "print(is_offline_mode())"
187
+
188
+ stdout, _ = self._execute_with_env(load, run)
189
+ self.assertIn("False", stdout)
190
+
191
+ stdout, _ = self._execute_with_env(load, run, TRANSFORMERS_OFFLINE="1")
192
+ self.assertIn("True", stdout)
193
+
194
+ stdout, _ = self._execute_with_env(load, run, HF_HUB_OFFLINE="1")
195
+ self.assertIn("True", stdout)
196
+
197
+ def _execute_with_env(self, *commands: tuple[str, ...], should_fail: bool = False, **env) -> tuple[str, str]:
198
+ """Execute Python code with a given environment and return the stdout/stderr as strings.
199
+
200
+ If `should_fail=True`, the command is expected to fail. Otherwise, it should succeed.
201
+ Environment variables can be passed as keyword arguments.
202
+ """
203
+ # Build command
204
+ cmd = [sys.executable, "-c", "\n".join(commands)]
205
+
206
+ # Configure env
207
+ new_env = self.get_env()
208
+ new_env.update(env)
209
+
210
+ # Run command
211
+ result = subprocess.run(cmd, env=new_env, check=False, capture_output=True)
212
+
213
+ # Check execution
214
+ if should_fail:
215
+ self.assertNotEqual(result.returncode, 0, result.stderr)
216
+ else:
217
+ self.assertEqual(result.returncode, 0, result.stderr)
218
+
219
+ # Return output
220
+ return result.stdout.decode(), result.stderr.decode()
docs/transformers/tests/utils/test_processing_utils.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 HuggingFace Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import unittest
16
+
17
+ import numpy as np
18
+
19
+ from transformers import is_torch_available, is_vision_available
20
+ from transformers.processing_utils import _validate_images_text_input_order
21
+ from transformers.testing_utils import require_torch, require_vision
22
+
23
+
24
+ if is_vision_available():
25
+ import PIL
26
+
27
+ if is_torch_available():
28
+ import torch
29
+
30
+
31
+ @require_vision
32
+ class ProcessingUtilTester(unittest.TestCase):
33
+ def test_validate_images_text_input_order(self):
34
+ # text string and PIL images inputs
35
+ images = PIL.Image.new("RGB", (224, 224))
36
+ text = "text"
37
+ # test correct text and images order
38
+ valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
39
+ self.assertEqual(valid_images, images)
40
+ self.assertEqual(valid_text, text)
41
+ # test incorrect text and images order
42
+ valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
43
+ self.assertEqual(valid_images, images)
44
+ self.assertEqual(valid_text, text)
45
+
46
+ # text list of string and numpy images inputs
47
+ images = np.random.rand(224, 224, 3)
48
+ text = ["text1", "text2"]
49
+ # test correct text and images order
50
+ valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
51
+ self.assertTrue(np.array_equal(valid_images, images))
52
+ self.assertEqual(valid_text, text)
53
+ # test incorrect text and images order
54
+ valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
55
+ self.assertTrue(np.array_equal(valid_images, images))
56
+ self.assertEqual(valid_text, text)
57
+
58
+ # text nested list of string and list of pil images inputs
59
+ images = [PIL.Image.new("RGB", (224, 224)), PIL.Image.new("RGB", (224, 224))]
60
+ text = [["text1", "text2, text3"], ["text3", "text4"]]
61
+ # test correct text and images order
62
+ valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
63
+ self.assertEqual(valid_images, images)
64
+ self.assertEqual(valid_text, text)
65
+ # test incorrect text and images order
66
+ valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
67
+ self.assertEqual(valid_images, images)
68
+ self.assertEqual(valid_text, text)
69
+
70
+ # list of strings and list of numpy images inputs
71
+ images = [np.random.rand(224, 224, 3), np.random.rand(224, 224, 3)]
72
+ text = ["text1", "text2"]
73
+ # test correct text and images order
74
+ valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
75
+ self.assertTrue(np.array_equal(valid_images[0], images[0]))
76
+ self.assertEqual(valid_text, text)
77
+ # test incorrect text and images order
78
+ valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
79
+ self.assertTrue(np.array_equal(valid_images[0], images[0]))
80
+ self.assertEqual(valid_text, text)
81
+
82
+ # list of strings and list of url images inputs
83
+ images = ["https://url1", "https://url2"]
84
+ text = ["text1", "text2"]
85
+ # test correct text and images order
86
+ valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
87
+ self.assertEqual(valid_images, images)
88
+ self.assertEqual(valid_text, text)
89
+ # test incorrect text and images order
90
+ valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
91
+ self.assertEqual(valid_images, images)
92
+ self.assertEqual(valid_text, text)
93
+
94
+ # list of strings and nested list of numpy images inputs
95
+ images = [[np.random.rand(224, 224, 3), np.random.rand(224, 224, 3)], [np.random.rand(224, 224, 3)]]
96
+ text = ["text1", "text2"]
97
+ # test correct text and images order
98
+ valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
99
+ self.assertTrue(np.array_equal(valid_images[0][0], images[0][0]))
100
+ self.assertEqual(valid_text, text)
101
+ # test incorrect text and images order
102
+ valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
103
+ self.assertTrue(np.array_equal(valid_images[0][0], images[0][0]))
104
+ self.assertEqual(valid_text, text)
105
+
106
+ # nested list of strings and nested list of PIL images inputs
107
+ images = [
108
+ [PIL.Image.new("RGB", (224, 224)), PIL.Image.new("RGB", (224, 224))],
109
+ [PIL.Image.new("RGB", (224, 224))],
110
+ ]
111
+ text = [["text1", "text2, text3"], ["text3", "text4"]]
112
+ # test correct text and images order
113
+ valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
114
+ self.assertEqual(valid_images, images)
115
+ self.assertEqual(valid_text, text)
116
+ # test incorrect text and images order
117
+ valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
118
+ self.assertEqual(valid_images, images)
119
+ self.assertEqual(valid_text, text)
120
+
121
+ # None images
122
+ images = None
123
+ text = "text"
124
+ # test correct text and images order
125
+ valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
126
+ self.assertEqual(images, None)
127
+ self.assertEqual(text, text)
128
+ # test incorrect text and images order
129
+ valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
130
+ self.assertEqual(images, None)
131
+ self.assertEqual(text, text)
132
+
133
+ # None text
134
+ images = PIL.Image.new("RGB", (224, 224))
135
+ text = None
136
+ # test correct text and images order
137
+ valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
138
+ self.assertEqual(images, images)
139
+ self.assertEqual(text, None)
140
+ # test incorrect text and images order
141
+ valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
142
+ self.assertEqual(images, images)
143
+ self.assertEqual(text, None)
144
+
145
+ # incorrect inputs
146
+ images = "text"
147
+ text = "text"
148
+ with self.assertRaises(ValueError):
149
+ _validate_images_text_input_order(images=images, text=text)
150
+
151
+ @require_torch
152
+ def test_validate_images_text_input_order_torch(self):
153
+ # text string and torch images inputs
154
+ images = torch.rand(224, 224, 3)
155
+ text = "text"
156
+ # test correct text and images order
157
+ valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
158
+ self.assertTrue(torch.equal(valid_images, images))
159
+ self.assertEqual(valid_text, text)
160
+ # test incorrect text and images order
161
+ valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
162
+ self.assertTrue(torch.equal(valid_images, images))
163
+ self.assertEqual(valid_text, text)
164
+
165
+ # text list of string and list of torch images inputs
166
+ images = [torch.rand(224, 224, 3), torch.rand(224, 224, 3)]
167
+ text = ["text1", "text2"]
168
+ # test correct text and images order
169
+ valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
170
+ self.assertTrue(torch.equal(valid_images[0], images[0]))
171
+ self.assertEqual(valid_text, text)
172
+ # test incorrect text and images order
173
+ valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
174
+ self.assertTrue(torch.equal(valid_images[0], images[0]))
175
+ self.assertEqual(valid_text, text)
docs/transformers/tests/utils/test_skip_decorators.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2019-present, the HuggingFace Inc. team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ #
16
+ #
17
+ # this test validates that we can stack skip decorators in groups and whether
18
+ # they work correctly with other decorators
19
+ #
20
+ # since the decorators have already built their decision params (like checking
21
+ # env[], we can't mock the env and test each of the combinations), so ideally
22
+ # the following 4 should be run. But since we have different CI jobs running
23
+ # different configs, all combinations should get covered
24
+ #
25
+ # RUN_SLOW=1 pytest -rA tests/test_skip_decorators.py
26
+ # RUN_SLOW=1 CUDA_VISIBLE_DEVICES="" pytest -rA tests/test_skip_decorators.py
27
+ # RUN_SLOW=0 pytest -rA tests/test_skip_decorators.py
28
+ # RUN_SLOW=0 CUDA_VISIBLE_DEVICES="" pytest -rA tests/test_skip_decorators.py
29
+
30
+ import os
31
+ import unittest
32
+
33
+ import pytest
34
+ from parameterized import parameterized
35
+
36
+ from transformers.testing_utils import require_torch, require_torch_gpu, slow, torch_device
37
+
38
+
39
+ # skipping in unittest tests
40
+
41
+ params = [(1,)]
42
+
43
+
44
+ # test that we can stack our skip decorators with 3rd party decorators
45
+ def check_slow():
46
+ run_slow = bool(os.getenv("RUN_SLOW", 0))
47
+ if run_slow:
48
+ assert True
49
+ else:
50
+ assert False, "should have been skipped"
51
+
52
+
53
+ # test that we can stack our skip decorators
54
+ def check_slow_torch_cuda():
55
+ run_slow = bool(os.getenv("RUN_SLOW", 0))
56
+ if run_slow and torch_device == "cuda":
57
+ assert True
58
+ else:
59
+ assert False, "should have been skipped"
60
+
61
+
62
+ @require_torch
63
+ class SkipTester(unittest.TestCase):
64
+ @slow
65
+ @require_torch_gpu
66
+ def test_2_skips_slow_first(self):
67
+ check_slow_torch_cuda()
68
+
69
+ @require_torch_gpu
70
+ @slow
71
+ def test_2_skips_slow_last(self):
72
+ check_slow_torch_cuda()
73
+
74
+ # The combination of any skip decorator, followed by parameterized fails to skip the tests
75
+ # 1. @slow manages to correctly skip `test_param_slow_first`
76
+ # 2. but then `parameterized` creates new tests, with a unique name for each parameter groups.
77
+ # It has no idea that they are to be skipped and so they all run, ignoring @slow
78
+ # Therefore skip decorators must come after `parameterized`
79
+ #
80
+ # @slow
81
+ # @parameterized.expand(params)
82
+ # def test_param_slow_first(self, param=None):
83
+ # check_slow()
84
+
85
+ # This works as expected:
86
+ # 1. `parameterized` creates new tests with unique names
87
+ # 2. each of them gets an opportunity to be skipped
88
+ @parameterized.expand(params)
89
+ @slow
90
+ def test_param_slow_last(self, param=None):
91
+ check_slow()
92
+
93
+
94
+ # skipping in non-unittest tests
95
+ # no problem at all here
96
+
97
+
98
+ @slow
99
+ @require_torch_gpu
100
+ def test_pytest_2_skips_slow_first():
101
+ check_slow_torch_cuda()
102
+
103
+
104
+ @require_torch_gpu
105
+ @slow
106
+ def test_pytest_2_skips_slow_last():
107
+ check_slow_torch_cuda()
108
+
109
+
110
+ @slow
111
+ @pytest.mark.parametrize("param", [1])
112
+ def test_pytest_param_slow_first(param):
113
+ check_slow()
114
+
115
+
116
+ @pytest.mark.parametrize("param", [1])
117
+ @slow
118
+ def test_pytest_param_slow_last(param):
119
+ check_slow()
docs/transformers/tests/utils/test_tokenization_utils.py ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2019 HuggingFace Inc.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import sys
17
+ import tempfile
18
+ import unittest
19
+ import unittest.mock as mock
20
+ from pathlib import Path
21
+
22
+ from huggingface_hub import HfFolder
23
+ from huggingface_hub.file_download import http_get
24
+ from requests.exceptions import HTTPError
25
+
26
+ from transformers import (
27
+ AlbertTokenizer,
28
+ AutoTokenizer,
29
+ BertTokenizer,
30
+ BertTokenizerFast,
31
+ GPT2TokenizerFast,
32
+ is_tokenizers_available,
33
+ )
34
+ from transformers.testing_utils import TOKEN, TemporaryHubRepo, is_staging_test, require_tokenizers
35
+ from transformers.tokenization_utils import ExtensionsTrie, Trie
36
+
37
+
38
+ sys.path.append(str(Path(__file__).parent.parent.parent / "utils"))
39
+
40
+ from test_module.custom_tokenization import CustomTokenizer # noqa E402
41
+
42
+
43
+ if is_tokenizers_available():
44
+ from test_module.custom_tokenization_fast import CustomTokenizerFast
45
+
46
+
47
+ class TokenizerUtilTester(unittest.TestCase):
48
+ def test_cached_files_are_used_when_internet_is_down(self):
49
+ # A mock response for an HTTP head request to emulate server down
50
+ response_mock = mock.Mock()
51
+ response_mock.status_code = 500
52
+ response_mock.headers = {}
53
+ response_mock.raise_for_status.side_effect = HTTPError
54
+ response_mock.json.return_value = {}
55
+
56
+ # Download this model to make sure it's in the cache.
57
+ _ = BertTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
58
+
59
+ # Under the mock environment we get a 500 error when trying to reach the tokenizer.
60
+ with mock.patch("requests.Session.request", return_value=response_mock) as mock_head:
61
+ _ = BertTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
62
+ # This check we did call the fake head request
63
+ mock_head.assert_called()
64
+
65
+ @require_tokenizers
66
+ def test_cached_files_are_used_when_internet_is_down_missing_files(self):
67
+ # A mock response for an HTTP head request to emulate server down
68
+ response_mock = mock.Mock()
69
+ response_mock.status_code = 500
70
+ response_mock.headers = {}
71
+ response_mock.raise_for_status.side_effect = HTTPError
72
+ response_mock.json.return_value = {}
73
+
74
+ # Download this model to make sure it's in the cache.
75
+ _ = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")
76
+
77
+ # Under the mock environment we get a 500 error when trying to reach the tokenizer.
78
+ with mock.patch("requests.Session.request", return_value=response_mock) as mock_head:
79
+ _ = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")
80
+ # This check we did call the fake head request
81
+ mock_head.assert_called()
82
+
83
+ def test_legacy_load_from_one_file(self):
84
+ # This test is for deprecated behavior and can be removed in v5
85
+ try:
86
+ tmp_file = tempfile.NamedTemporaryFile(delete=False).name
87
+ with open(tmp_file, "wb") as f:
88
+ http_get("https://huggingface.co/albert/albert-base-v1/resolve/main/spiece.model", f)
89
+
90
+ _ = AlbertTokenizer.from_pretrained(tmp_file)
91
+ finally:
92
+ os.remove(tmp_file)
93
+
94
+ # Supporting this legacy load introduced a weird bug where the tokenizer would load local files if they are in
95
+ # the current folder and have the right name.
96
+ if os.path.isfile("tokenizer.json"):
97
+ # We skip the test if the user has a `tokenizer.json` in this folder to avoid deleting it.
98
+ self.skipTest(reason="Skipping test as there is a `tokenizer.json` file in the current folder.")
99
+ try:
100
+ with open("tokenizer.json", "wb") as f:
101
+ http_get("https://huggingface.co/hf-internal-testing/tiny-random-bert/blob/main/tokenizer.json", f)
102
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
103
+ # The tiny random BERT has a vocab size of 1024, tiny openai-community/gpt2 as a vocab size of 1000
104
+ self.assertEqual(tokenizer.vocab_size, 1000)
105
+ # Tokenizer should depend on the remote checkpoint, not the local tokenizer.json file.
106
+
107
+ finally:
108
+ os.remove("tokenizer.json")
109
+
110
+
111
+ @is_staging_test
112
+ class TokenizerPushToHubTester(unittest.TestCase):
113
+ vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]", "bla", "blou"]
114
+
115
+ @classmethod
116
+ def setUpClass(cls):
117
+ cls._token = TOKEN
118
+ HfFolder.save_token(TOKEN)
119
+
120
+ def test_push_to_hub(self):
121
+ with TemporaryHubRepo(token=self._token) as tmp_repo:
122
+ with tempfile.TemporaryDirectory() as tmp_dir:
123
+ vocab_file = os.path.join(tmp_dir, "vocab.txt")
124
+ with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
125
+ vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
126
+ tokenizer = BertTokenizer(vocab_file)
127
+
128
+ tokenizer.push_to_hub(tmp_repo.repo_id, token=self._token)
129
+ new_tokenizer = BertTokenizer.from_pretrained(tmp_repo.repo_id)
130
+ self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)
131
+
132
+ def test_push_to_hub_via_save_pretrained(self):
133
+ with TemporaryHubRepo(token=self._token) as tmp_repo:
134
+ with tempfile.TemporaryDirectory() as tmp_dir:
135
+ vocab_file = os.path.join(tmp_dir, "vocab.txt")
136
+ with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
137
+ vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
138
+ tokenizer = BertTokenizer(vocab_file)
139
+
140
+ # Push to hub via save_pretrained
141
+ tokenizer.save_pretrained(tmp_dir, repo_id=tmp_repo.repo_id, push_to_hub=True, token=self._token)
142
+
143
+ new_tokenizer = BertTokenizer.from_pretrained(tmp_repo.repo_id)
144
+ self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)
145
+
146
+ def test_push_to_hub_in_organization(self):
147
+ with TemporaryHubRepo(namespace="valid_org", token=self._token) as tmp_repo:
148
+ with tempfile.TemporaryDirectory() as tmp_dir:
149
+ vocab_file = os.path.join(tmp_dir, "vocab.txt")
150
+ with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
151
+ vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
152
+ tokenizer = BertTokenizer(vocab_file)
153
+
154
+ tokenizer.push_to_hub(tmp_repo.repo_id, token=self._token)
155
+ new_tokenizer = BertTokenizer.from_pretrained(tmp_repo.repo_id)
156
+ self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)
157
+
158
+ def test_push_to_hub_in_organization_via_save_pretrained(self):
159
+ with TemporaryHubRepo(namespace="valid_org", token=self._token) as tmp_repo:
160
+ with tempfile.TemporaryDirectory() as tmp_dir:
161
+ vocab_file = os.path.join(tmp_dir, "vocab.txt")
162
+ with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
163
+ vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
164
+ tokenizer = BertTokenizer(vocab_file)
165
+
166
+ # Push to hub via save_pretrained
167
+ tokenizer.save_pretrained(tmp_dir, repo_id=tmp_repo.repo_id, push_to_hub=True, token=self._token)
168
+
169
+ new_tokenizer = BertTokenizer.from_pretrained(tmp_repo.repo_id)
170
+ self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)
171
+
172
+ @require_tokenizers
173
+ def test_push_to_hub_dynamic_tokenizer(self):
174
+ with TemporaryHubRepo(token=self._token) as tmp_repo:
175
+ CustomTokenizer.register_for_auto_class()
176
+ with tempfile.TemporaryDirectory() as tmp_dir:
177
+ vocab_file = os.path.join(tmp_dir, "vocab.txt")
178
+ with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
179
+ vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
180
+ tokenizer = CustomTokenizer(vocab_file)
181
+
182
+ # No fast custom tokenizer
183
+ tokenizer.push_to_hub(tmp_repo.repo_id, token=self._token)
184
+
185
+ tokenizer = AutoTokenizer.from_pretrained(tmp_repo.repo_id, trust_remote_code=True)
186
+ # Can't make an isinstance check because the new_model.config is from the CustomTokenizer class of a dynamic module
187
+ self.assertEqual(tokenizer.__class__.__name__, "CustomTokenizer")
188
+
189
+ @require_tokenizers
190
+ def test_push_to_hub_dynamic_tokenizer_with_both_slow_and_fast_classes(self):
191
+ with TemporaryHubRepo(token=self._token) as tmp_repo:
192
+ CustomTokenizer.register_for_auto_class()
193
+
194
+ # Fast and slow custom tokenizer
195
+ CustomTokenizerFast.register_for_auto_class()
196
+
197
+ with tempfile.TemporaryDirectory() as tmp_dir:
198
+ vocab_file = os.path.join(tmp_dir, "vocab.txt")
199
+ with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
200
+ vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
201
+
202
+ bert_tokenizer = BertTokenizerFast.from_pretrained(tmp_dir)
203
+ bert_tokenizer.save_pretrained(tmp_dir)
204
+ tokenizer = CustomTokenizerFast.from_pretrained(tmp_dir)
205
+
206
+ tokenizer.push_to_hub(tmp_repo.repo_id, token=self._token)
207
+
208
+ tokenizer = AutoTokenizer.from_pretrained(tmp_repo.repo_id, trust_remote_code=True)
209
+ # Can't make an isinstance check because the new_model.config is from the FakeConfig class of a dynamic module
210
+ self.assertEqual(tokenizer.__class__.__name__, "CustomTokenizerFast")
211
+ tokenizer = AutoTokenizer.from_pretrained(tmp_repo.repo_id, use_fast=False, trust_remote_code=True)
212
+ # Can't make an isinstance check because the new_model.config is from the FakeConfig class of a dynamic module
213
+ self.assertEqual(tokenizer.__class__.__name__, "CustomTokenizer")
214
+
215
+
216
+ class TrieTest(unittest.TestCase):
217
+ def test_trie(self):
218
+ trie = Trie()
219
+ trie.add("Hello 友達")
220
+ self.assertEqual(trie.data, {"H": {"e": {"l": {"l": {"o": {" ": {"友": {"達": {"": 1}}}}}}}}})
221
+ trie.add("Hello")
222
+ self.assertEqual(trie.data, {"H": {"e": {"l": {"l": {"o": {"": 1, " ": {"友": {"達": {"": 1}}}}}}}}})
223
+
224
+ def test_trie_split(self):
225
+ trie = Trie()
226
+ self.assertEqual(trie.split("[CLS] This is a extra_id_100"), ["[CLS] This is a extra_id_100"])
227
+ trie.add("[CLS]")
228
+ trie.add("extra_id_1")
229
+ trie.add("extra_id_100")
230
+ self.assertEqual(trie.split("[CLS] This is a extra_id_100"), ["[CLS]", " This is a ", "extra_id_100"])
231
+
232
+ def test_trie_single(self):
233
+ trie = Trie()
234
+ trie.add("A")
235
+ self.assertEqual(trie.split("ABC"), ["A", "BC"])
236
+ self.assertEqual(trie.split("BCA"), ["BC", "A"])
237
+
238
+ def test_trie_final(self):
239
+ trie = Trie()
240
+ trie.add("TOKEN]")
241
+ trie.add("[SPECIAL_TOKEN]")
242
+ self.assertEqual(trie.split("This is something [SPECIAL_TOKEN]"), ["This is something ", "[SPECIAL_TOKEN]"])
243
+
244
+ def test_trie_subtokens(self):
245
+ trie = Trie()
246
+ trie.add("A")
247
+ trie.add("P")
248
+ trie.add("[SPECIAL_TOKEN]")
249
+ self.assertEqual(trie.split("This is something [SPECIAL_TOKEN]"), ["This is something ", "[SPECIAL_TOKEN]"])
250
+
251
+ def test_trie_suffix_tokens(self):
252
+ trie = Trie()
253
+ trie.add("AB")
254
+ trie.add("B")
255
+ trie.add("C")
256
+ self.assertEqual(trie.split("ABC"), ["AB", "C"])
257
+
258
+ def test_trie_skip(self):
259
+ trie = Trie()
260
+ trie.add("ABC")
261
+ trie.add("B")
262
+ trie.add("CD")
263
+ self.assertEqual(trie.split("ABCD"), ["ABC", "D"])
264
+
265
+ def test_cut_text_hardening(self):
266
+ # Even if the offsets are wrong, we necessarily output correct string
267
+ # parts.
268
+ trie = Trie()
269
+ parts = trie.cut_text("ABC", [0, 0, 2, 1, 2, 3])
270
+ self.assertEqual(parts, ["AB", "C"])
271
+
272
+
273
+ class ExtensionsTrieTest(unittest.TestCase):
274
+ def test_extensions(self):
275
+ # Test searching by prefix
276
+ trie = ExtensionsTrie()
277
+ trie.add("foo")
278
+ trie.add("food")
279
+ trie.add("foodie")
280
+ trie.add("helium")
281
+ self.assertEqual(trie.extensions("foo"), ["foo", "food", "foodie"])
282
+ self.assertEqual(trie.extensions("helium"), ["helium"])
283
+
284
+ def test_empty_prefix(self):
285
+ trie = ExtensionsTrie()
286
+ # Test searching with an empty prefix returns all values
287
+ trie.add("hello")
288
+ trie.add("bye")
289
+ self.assertEqual(trie.extensions(""), ["hello", "bye"])
290
+
291
+ def test_no_extension_match(self):
292
+ trie = ExtensionsTrie()
293
+ # Test searching for a prefix that doesn't match any key
294
+ values = trie.extensions("unknown")
295
+
296
+ self.assertEqual(len(values), 0)
297
+
298
+ def test_update_value(self):
299
+ trie = ExtensionsTrie()
300
+ # Test updating the value of an existing key
301
+ trie.add("hi")
302
+ trie.add("hi")
303
+ self.assertEqual(trie.extensions("hi"), ["hi"])
docs/transformers/tests/utils/test_versions_utils.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import importlib.metadata
16
+ import sys
17
+
18
+ from transformers.testing_utils import TestCasePlus
19
+ from transformers.utils.versions import require_version, require_version_core
20
+
21
+
22
+ numpy_ver = importlib.metadata.version("numpy")
23
+ python_ver = ".".join([str(x) for x in sys.version_info[:3]])
24
+
25
+
26
+ class DependencyVersionCheckTest(TestCasePlus):
27
+ def test_core(self):
28
+ # lt + different version strings
29
+ require_version_core("numpy<1000.4.5")
30
+ require_version_core("numpy<1000.4")
31
+ require_version_core("numpy<1000")
32
+
33
+ # le
34
+ require_version_core("numpy<=1000.4.5")
35
+ require_version_core(f"numpy<={numpy_ver}")
36
+
37
+ # eq
38
+ require_version_core(f"numpy=={numpy_ver}")
39
+
40
+ # ne
41
+ require_version_core("numpy!=1000.4.5")
42
+
43
+ # ge
44
+ require_version_core("numpy>=1.0")
45
+ require_version_core("numpy>=1.0.0")
46
+ require_version_core(f"numpy>={numpy_ver}")
47
+
48
+ # gt
49
+ require_version_core("numpy>1.0.0")
50
+
51
+ # mix
52
+ require_version_core("numpy>1.0.0,<1000")
53
+
54
+ # requirement w/o version
55
+ require_version_core("numpy")
56
+
57
+ # unmet requirements due to version conflict
58
+ for req in ["numpy==1.0.0", "numpy>=1000.0.0", f"numpy<{numpy_ver}"]:
59
+ try:
60
+ require_version_core(req)
61
+ except ImportError as e:
62
+ self.assertIn(f"{req} is required", str(e))
63
+ self.assertIn("but found", str(e))
64
+
65
+ # unmet requirements due to missing module
66
+ for req in ["numpipypie>1", "numpipypie2"]:
67
+ try:
68
+ require_version_core(req)
69
+ except importlib.metadata.PackageNotFoundError as e:
70
+ self.assertIn(f"The '{req}' distribution was not found and is required by this application", str(e))
71
+ self.assertIn("Try: `pip install transformers -U`", str(e))
72
+
73
+ # bogus requirements formats:
74
+ # 1. whole thing
75
+ for req in ["numpy??1.0.0", "numpy1.0.0"]:
76
+ try:
77
+ require_version_core(req)
78
+ except ValueError as e:
79
+ self.assertIn("requirement needs to be in the pip package format", str(e))
80
+ # 2. only operators
81
+ for req in ["numpy=1.0.0", "numpy == 1.00", "numpy<>1.0.0", "numpy><1.00", "numpy>>1.0.0"]:
82
+ try:
83
+ require_version_core(req)
84
+ except ValueError as e:
85
+ self.assertIn("need one of ", str(e))
86
+
87
+ def test_python(self):
88
+ # matching requirement
89
+ require_version("python>=3.9.0")
90
+
91
+ # not matching requirements
92
+ for req in ["python>9.9.9", "python<3.0.0"]:
93
+ try:
94
+ require_version_core(req)
95
+ except ImportError as e:
96
+ self.assertIn(f"{req} is required", str(e))
97
+ self.assertIn(f"but found python=={python_ver}", str(e))
docs/transformers/tests/utils/tiny_model_summary.json ADDED
The diff for this file is too large to render. See raw diff
 
docs/transformers/utils/add_pipeline_model_mapping_to_test.py ADDED
@@ -0,0 +1,336 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """A script to add and/or update the attribute `pipeline_model_mapping` in model test files.
16
+
17
+ This script will be (mostly) used in the following 2 situations:
18
+
19
+ - run within a (scheduled) CI job to:
20
+ - check if model test files in the library have updated `pipeline_model_mapping`,
21
+ - and/or update test files and (possibly) open a GitHub pull request automatically
22
+ - being run by a `transformers` member to quickly check and update some particular test file(s)
23
+
24
+ This script is **NOT** intended to be run (manually) by community contributors.
25
+ """
26
+
27
+ import argparse
28
+ import glob
29
+ import inspect
30
+ import os
31
+ import re
32
+ import unittest
33
+
34
+ from get_test_info import get_test_classes
35
+
36
+ from tests.test_pipeline_mixin import pipeline_test_mapping
37
+
38
+
39
+ PIPELINE_TEST_MAPPING = {}
40
+ for task, _ in pipeline_test_mapping.items():
41
+ PIPELINE_TEST_MAPPING[task] = {"pt": None, "tf": None}
42
+
43
+
44
+ # DO **NOT** add item to this set (unless the reason is approved)
45
+ TEST_FILE_TO_IGNORE = {
46
+ "tests/models/esm/test_modeling_esmfold.py", # The pipeline test mapping is added to `test_modeling_esm.py`
47
+ }
48
+
49
+
50
+ def get_framework(test_class):
51
+ """Infer the framework from the test class `test_class`."""
52
+
53
+ if "ModelTesterMixin" in [x.__name__ for x in test_class.__bases__]:
54
+ return "pt"
55
+ elif "TFModelTesterMixin" in [x.__name__ for x in test_class.__bases__]:
56
+ return "tf"
57
+ elif "FlaxModelTesterMixin" in [x.__name__ for x in test_class.__bases__]:
58
+ return "flax"
59
+ else:
60
+ return None
61
+
62
+
63
+ def get_mapping_for_task(task, framework):
64
+ """Get mappings defined in `XXXPipelineTests` for the task `task`."""
65
+ # Use the cached results
66
+ if PIPELINE_TEST_MAPPING[task].get(framework, None) is not None:
67
+ return PIPELINE_TEST_MAPPING[task][framework]
68
+
69
+ pipeline_test_class = pipeline_test_mapping[task]["test"]
70
+ mapping = None
71
+
72
+ if framework == "pt":
73
+ mapping = getattr(pipeline_test_class, "model_mapping", None)
74
+ elif framework == "tf":
75
+ mapping = getattr(pipeline_test_class, "tf_model_mapping", None)
76
+
77
+ if mapping is not None:
78
+ mapping = dict(mapping.items())
79
+
80
+ # cache the results
81
+ PIPELINE_TEST_MAPPING[task][framework] = mapping
82
+ return mapping
83
+
84
+
85
+ def get_model_for_pipeline_test(test_class, task):
86
+ """Get the model architecture(s) related to the test class `test_class` for a pipeline `task`."""
87
+ framework = get_framework(test_class)
88
+ if framework is None:
89
+ return None
90
+ mapping = get_mapping_for_task(task, framework)
91
+ if mapping is None:
92
+ return None
93
+
94
+ config_classes = list({model_class.config_class for model_class in test_class.all_model_classes})
95
+ if len(config_classes) != 1:
96
+ raise ValueError("There should be exactly one configuration class from `test_class.all_model_classes`.")
97
+
98
+ # This could be a list/tuple of model classes, but it's rare.
99
+ model_class = mapping.get(config_classes[0], None)
100
+ if isinstance(model_class, (tuple, list)):
101
+ model_class = sorted(model_class, key=lambda x: x.__name__)
102
+
103
+ return model_class
104
+
105
+
106
+ def get_pipeline_model_mapping(test_class):
107
+ """Get `pipeline_model_mapping` for `test_class`."""
108
+ mapping = [(task, get_model_for_pipeline_test(test_class, task)) for task in pipeline_test_mapping]
109
+ mapping = sorted([(task, model) for task, model in mapping if model is not None], key=lambda x: x[0])
110
+
111
+ return dict(mapping)
112
+
113
+
114
+ def get_pipeline_model_mapping_string(test_class):
115
+ """Get `pipeline_model_mapping` for `test_class` as a string (to be added to the test file).
116
+
117
+ This will be a 1-line string. After this is added to a test file, `make style` will format it beautifully.
118
+ """
119
+ framework = get_framework(test_class)
120
+ if framework == "pt":
121
+ framework = "torch"
122
+ default_value = "{}"
123
+
124
+ mapping = get_pipeline_model_mapping(test_class)
125
+ if len(mapping) == 0:
126
+ return ""
127
+
128
+ texts = []
129
+ for task, model_classes in mapping.items():
130
+ if isinstance(model_classes, (tuple, list)):
131
+ # A list/tuple of model classes
132
+ value = "(" + ", ".join([x.__name__ for x in model_classes]) + ")"
133
+ else:
134
+ # A single model class
135
+ value = model_classes.__name__
136
+ texts.append(f'"{task}": {value}')
137
+ text = "{" + ", ".join(texts) + "}"
138
+ text = f"pipeline_model_mapping = {text} if is_{framework}_available() else {default_value}"
139
+
140
+ return text
141
+
142
+
143
+ def is_valid_test_class(test_class):
144
+ """Restrict to `XXXModelTesterMixin` and should be a subclass of `unittest.TestCase`."""
145
+ base_class_names = {"ModelTesterMixin", "TFModelTesterMixin", "FlaxModelTesterMixin"}
146
+ if not issubclass(test_class, unittest.TestCase):
147
+ return False
148
+ return len(base_class_names.intersection([x.__name__ for x in test_class.__bases__])) > 0
149
+
150
+
151
+ def find_test_class(test_file):
152
+ """Find a test class in `test_file` to which we will add `pipeline_model_mapping`."""
153
+ test_classes = [x for x in get_test_classes(test_file) if is_valid_test_class(x)]
154
+
155
+ target_test_class = None
156
+ for test_class in test_classes:
157
+ # If a test class has defined `pipeline_model_mapping`, let's take it
158
+ if getattr(test_class, "pipeline_model_mapping", None) is not None:
159
+ target_test_class = test_class
160
+ break
161
+ # Take the test class with the shortest name (just a heuristic)
162
+ if target_test_class is None and len(test_classes) > 0:
163
+ target_test_class = sorted(test_classes, key=lambda x: (len(x.__name__), x.__name__))[0]
164
+
165
+ return target_test_class
166
+
167
+
168
+ def find_block_ending(lines, start_idx, indent_level):
169
+ end_idx = start_idx
170
+ for idx, line in enumerate(lines[start_idx:]):
171
+ indent = len(line) - len(line.lstrip())
172
+ if idx == 0 or indent > indent_level or (indent == indent_level and line.strip() == ")"):
173
+ end_idx = start_idx + idx
174
+ elif idx > 0 and indent <= indent_level:
175
+ # Outside the definition block of `pipeline_model_mapping`
176
+ break
177
+
178
+ return end_idx
179
+
180
+
181
+ def add_pipeline_model_mapping(test_class, overwrite=False):
182
+ """Add `pipeline_model_mapping` to `test_class`."""
183
+ if getattr(test_class, "pipeline_model_mapping", None) is not None:
184
+ if not overwrite:
185
+ return "", -1
186
+
187
+ line_to_add = get_pipeline_model_mapping_string(test_class)
188
+ if len(line_to_add) == 0:
189
+ return "", -1
190
+ line_to_add = line_to_add + "\n"
191
+
192
+ # The code defined the class `test_class`
193
+ class_lines, class_start_line_no = inspect.getsourcelines(test_class)
194
+ # `inspect` gives the code for an object, including decorator(s) if any.
195
+ # We (only) need the exact line of the class definition.
196
+ for idx, line in enumerate(class_lines):
197
+ if line.lstrip().startswith("class "):
198
+ class_lines = class_lines[idx:]
199
+ class_start_line_no += idx
200
+ break
201
+ class_end_line_no = class_start_line_no + len(class_lines) - 1
202
+
203
+ # The index in `class_lines` that starts the definition of `all_model_classes`, `all_generative_model_classes` or
204
+ # `pipeline_model_mapping`. This assumes they are defined in such order, and we take the start index of the last
205
+ # block that appears in a `test_class`.
206
+ start_idx = None
207
+ # The indent level of the line at `class_lines[start_idx]` (if defined)
208
+ indent_level = 0
209
+ # To record if `pipeline_model_mapping` is found in `test_class`.
210
+ def_line = None
211
+ for idx, line in enumerate(class_lines):
212
+ if line.strip().startswith("all_model_classes = "):
213
+ indent_level = len(line) - len(line.lstrip())
214
+ start_idx = idx
215
+ elif line.strip().startswith("all_generative_model_classes = "):
216
+ indent_level = len(line) - len(line.lstrip())
217
+ start_idx = idx
218
+ elif line.strip().startswith("pipeline_model_mapping = "):
219
+ indent_level = len(line) - len(line.lstrip())
220
+ start_idx = idx
221
+ def_line = line
222
+ break
223
+
224
+ if start_idx is None:
225
+ return "", -1
226
+ # Find the ending index (inclusive) of the above found block.
227
+ end_idx = find_block_ending(class_lines, start_idx, indent_level)
228
+
229
+ # Extract `is_xxx_available()` from existing blocks: some models require specific libraries like `timm` and use
230
+ # `is_timm_available()` instead of `is_torch_available()`.
231
+ # Keep leading and trailing whitespaces
232
+ r = re.compile(r"\s(is_\S+?_available\(\))\s")
233
+ for line in class_lines[start_idx : end_idx + 1]:
234
+ backend_condition = r.search(line)
235
+ if backend_condition is not None:
236
+ # replace the leading and trailing whitespaces to the space character " ".
237
+ target = " " + backend_condition[0][1:-1] + " "
238
+ line_to_add = r.sub(target, line_to_add)
239
+ break
240
+
241
+ if def_line is None:
242
+ # `pipeline_model_mapping` is not defined. The target index is set to the ending index (inclusive) of
243
+ # `all_model_classes` or `all_generative_model_classes`.
244
+ target_idx = end_idx
245
+ else:
246
+ # `pipeline_model_mapping` is defined. The target index is set to be one **BEFORE** its start index.
247
+ target_idx = start_idx - 1
248
+ # mark the lines of the currently existing `pipeline_model_mapping` to be removed.
249
+ for idx in range(start_idx, end_idx + 1):
250
+ # These lines are going to be removed before writing to the test file.
251
+ class_lines[idx] = None # noqa
252
+
253
+ # Make sure the test class is a subclass of `PipelineTesterMixin`.
254
+ parent_classes = [x.__name__ for x in test_class.__bases__]
255
+ if "PipelineTesterMixin" not in parent_classes:
256
+ # Put `PipelineTesterMixin` just before `unittest.TestCase`
257
+ _parent_classes = [x for x in parent_classes if x != "TestCase"] + ["PipelineTesterMixin"]
258
+ if "TestCase" in parent_classes:
259
+ # Here we **assume** the original string is always with `unittest.TestCase`.
260
+ _parent_classes.append("unittest.TestCase")
261
+ parent_classes = ", ".join(_parent_classes)
262
+ for idx, line in enumerate(class_lines):
263
+ # Find the ending of the declaration of `test_class`
264
+ if line.strip().endswith("):"):
265
+ # mark the lines of the declaration of `test_class` to be removed
266
+ for _idx in range(idx + 1):
267
+ class_lines[_idx] = None # noqa
268
+ break
269
+ # Add the new, one-line, class declaration for `test_class`
270
+ class_lines[0] = f"class {test_class.__name__}({parent_classes}):\n"
271
+
272
+ # Add indentation
273
+ line_to_add = " " * indent_level + line_to_add
274
+ # Insert `pipeline_model_mapping` to `class_lines`.
275
+ # (The line at `target_idx` should be kept by definition!)
276
+ class_lines = class_lines[: target_idx + 1] + [line_to_add] + class_lines[target_idx + 1 :]
277
+ # Remove the lines that are marked to be removed
278
+ class_lines = [x for x in class_lines if x is not None]
279
+
280
+ # Move from test class to module (in order to write to the test file)
281
+ module_lines = inspect.getsourcelines(inspect.getmodule(test_class))[0]
282
+ # Be careful with the 1-off between line numbers and array indices
283
+ module_lines = module_lines[: class_start_line_no - 1] + class_lines + module_lines[class_end_line_no:]
284
+ code = "".join(module_lines)
285
+
286
+ moddule_file = inspect.getsourcefile(test_class)
287
+ with open(moddule_file, "w", encoding="UTF-8", newline="\n") as fp:
288
+ fp.write(code)
289
+
290
+ return line_to_add
291
+
292
+
293
+ def add_pipeline_model_mapping_to_test_file(test_file, overwrite=False):
294
+ """Add `pipeline_model_mapping` to `test_file`."""
295
+ test_class = find_test_class(test_file)
296
+ if test_class:
297
+ add_pipeline_model_mapping(test_class, overwrite=overwrite)
298
+
299
+
300
+ if __name__ == "__main__":
301
+ parser = argparse.ArgumentParser()
302
+ parser.add_argument(
303
+ "--test_file", type=str, help="A path to the test file, starting with the repository's `tests` directory."
304
+ )
305
+ parser.add_argument(
306
+ "--all",
307
+ action="store_true",
308
+ help="If to check and modify all test files.",
309
+ )
310
+ parser.add_argument(
311
+ "--overwrite",
312
+ action="store_true",
313
+ help="If to overwrite a test class if it has already defined `pipeline_model_mapping`.",
314
+ )
315
+ args = parser.parse_args()
316
+
317
+ if not args.all and not args.test_file:
318
+ raise ValueError("Please specify either `test_file` or pass `--all` to check/modify all test files.")
319
+ elif args.all and args.test_file:
320
+ raise ValueError("Only one of `--test_file` and `--all` could be specified.")
321
+
322
+ test_files = []
323
+ if args.test_file:
324
+ test_files = [args.test_file]
325
+ else:
326
+ pattern = os.path.join("tests", "models", "**", "test_modeling_*.py")
327
+ for test_file in glob.glob(pattern):
328
+ # `Flax` is not concerned at this moment
329
+ if not test_file.startswith("test_modeling_flax_"):
330
+ test_files.append(test_file)
331
+
332
+ for test_file in test_files:
333
+ if test_file in TEST_FILE_TO_IGNORE:
334
+ print(f"[SKIPPED] {test_file} is skipped as it is in `TEST_FILE_TO_IGNORE` in the file {__file__}.")
335
+ continue
336
+ add_pipeline_model_mapping_to_test_file(test_file, overwrite=args.overwrite)
docs/transformers/utils/check_bad_commit.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+
4
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ import argparse
18
+ import json
19
+ import os
20
+ import re
21
+ import subprocess
22
+
23
+ import requests
24
+
25
+
26
+ def create_script(target_test):
27
+ """Create a python script to be run by `git bisect run` to determine if `target_test` passes or fails.
28
+ If a test is not found in a commit, the script with exit code `0` (i.e. `Success`).
29
+
30
+ Args:
31
+ target_test (`str`): The test to check.
32
+
33
+ Returns:
34
+ `str`: The script to be run by `git bisect run`.
35
+ """
36
+
37
+ script = f"""
38
+ import os
39
+ import subprocess
40
+
41
+ result = subprocess.run(
42
+ ["python3", "-m", "pytest", "-v", f"{target_test}"],
43
+ capture_output = True,
44
+ text=True,
45
+ )
46
+ print(result.stdout)
47
+
48
+ if len(result.stderr) > 0:
49
+ if "ERROR: file or directory not found: " in result.stderr:
50
+ print("test file or directory not found in this commit")
51
+ exit(0)
52
+ elif "ERROR: not found: " in result.stderr:
53
+ print("test not found in this commit")
54
+ exit(0)
55
+ else:
56
+ print(f"pytest failed to run: {{result.stderr}}")
57
+ exit(-1)
58
+ elif f"{target_test} FAILED" in result.stdout:
59
+ print("test failed")
60
+ exit(2)
61
+
62
+ exit(0)
63
+ """
64
+
65
+ with open("target_script.py", "w") as fp:
66
+ fp.write(script.strip())
67
+
68
+
69
+ def find_bad_commit(target_test, start_commit, end_commit):
70
+ """Find (backward) the earliest commit between `start_commit` and `end_commit` at which `target_test` fails.
71
+
72
+ Args:
73
+ target_test (`str`): The test to check.
74
+ start_commit (`str`): The latest commit.
75
+ end_commit (`str`): The earliest commit.
76
+
77
+ Returns:
78
+ `str`: The earliest commit at which `target_test` fails.
79
+ """
80
+
81
+ if start_commit == end_commit:
82
+ return start_commit
83
+
84
+ create_script(target_test=target_test)
85
+
86
+ bash = f"""
87
+ git bisect reset
88
+ git bisect start {start_commit} {end_commit}
89
+ git bisect run python3 target_script.py
90
+ """
91
+
92
+ with open("run_git_bisect.sh", "w") as fp:
93
+ fp.write(bash.strip())
94
+
95
+ result = subprocess.run(
96
+ ["bash", "run_git_bisect.sh"],
97
+ capture_output=True,
98
+ text=True,
99
+ )
100
+ print(result.stdout)
101
+
102
+ if "error: bisect run failed" in result.stderr:
103
+ index = result.stderr.find("error: bisect run failed")
104
+ bash_error = result.stderr[index:]
105
+
106
+ error_msg = f"Error when running git bisect:\nbash error: {bash_error}"
107
+
108
+ pattern = "pytest failed to run: .+"
109
+ pytest_errors = re.findall(pattern, result.stdout)
110
+ if len(pytest_errors) > 0:
111
+ pytest_error = pytest_errors[0]
112
+ index = pytest_error.find("pytest failed to run: ")
113
+ index += len("pytest failed to run: ")
114
+ pytest_error = pytest_error[index:]
115
+ error_msg += f"pytest error: {pytest_error}"
116
+
117
+ raise ValueError(error_msg)
118
+
119
+ pattern = r"(.+) is the first bad commit"
120
+ commits = re.findall(pattern, result.stdout)
121
+
122
+ bad_commit = None
123
+ if len(commits) > 0:
124
+ bad_commit = commits[0]
125
+
126
+ print(f"Between `start_commit` {start_commit} and `end_commit` {end_commit}")
127
+ print(f"bad_commit: {bad_commit}\n")
128
+
129
+ return bad_commit
130
+
131
+
132
+ def get_commit_info(commit):
133
+ """Get information for a commit via `api.github.com`."""
134
+ pr_number = None
135
+ author = None
136
+ merged_author = None
137
+
138
+ url = f"https://api.github.com/repos/huggingface/transformers/commits/{commit}/pulls"
139
+ pr_info_for_commit = requests.get(url).json()
140
+
141
+ if len(pr_info_for_commit) > 0:
142
+ pr_number = pr_info_for_commit[0]["number"]
143
+
144
+ url = f"https://api.github.com/repos/huggingface/transformers/pulls/{pr_number}"
145
+ pr_for_commit = requests.get(url).json()
146
+ author = pr_for_commit["user"]["login"]
147
+ merged_author = pr_for_commit["merged_by"]["login"]
148
+
149
+ if author is None:
150
+ url = f"https://api.github.com/repos/huggingface/transformers/commits/{commit}"
151
+ commit_info = requests.get(url).json()
152
+ author = commit_info["author"]["login"]
153
+
154
+ return {"commit": commit, "pr_number": pr_number, "author": author, "merged_by": merged_author}
155
+
156
+
157
+ if __name__ == "__main__":
158
+ parser = argparse.ArgumentParser()
159
+ parser.add_argument("--start_commit", type=str, required=True, help="The latest commit hash to check.")
160
+ parser.add_argument("--end_commit", type=str, required=True, help="The earliest commit hash to check.")
161
+ parser.add_argument("--test", type=str, help="The test to check.")
162
+ parser.add_argument("--file", type=str, help="The report file.")
163
+ parser.add_argument("--output_file", type=str, required=True, help="The path of the output file.")
164
+ args = parser.parse_args()
165
+
166
+ print(f"start_commit: {args.start_commit}")
167
+ print(f"end_commit: {args.end_commit}")
168
+
169
+ if len({args.test is None, args.file is None}) != 2:
170
+ raise ValueError("Exactly one argument `test` or `file` must be specified.")
171
+
172
+ if args.test is not None:
173
+ commit = find_bad_commit(target_test=args.test, start_commit=args.start_commit, end_commit=args.end_commit)
174
+ with open(args.output_file, "w", encoding="UTF-8") as fp:
175
+ fp.write(f"{args.test}\n{commit}")
176
+ elif os.path.isfile(args.file):
177
+ with open(args.file, "r", encoding="UTF-8") as fp:
178
+ reports = json.load(fp)
179
+
180
+ for model in reports:
181
+ # TODO: make this script able to deal with both `single-gpu` and `multi-gpu` via a new argument.
182
+ reports[model].pop("multi-gpu", None)
183
+ failed_tests = reports[model]["single-gpu"]
184
+
185
+ failed_tests_with_bad_commits = []
186
+ for test in failed_tests:
187
+ commit = find_bad_commit(target_test=test, start_commit=args.start_commit, end_commit=args.end_commit)
188
+ info = {"test": test, "commit": commit}
189
+ info.update(get_commit_info(commit))
190
+ failed_tests_with_bad_commits.append(info)
191
+
192
+ # If no single-gpu test failures, remove the key
193
+ if len(failed_tests_with_bad_commits) > 0:
194
+ reports[model]["single-gpu"] = failed_tests_with_bad_commits
195
+ else:
196
+ reports[model].pop("single-gpu", None)
197
+
198
+ # remove the models without any test failure
199
+ reports = {k: v for k, v in reports.items() if len(v) > 0}
200
+
201
+ with open(args.output_file, "w", encoding="UTF-8") as fp:
202
+ json.dump(reports, fp, ensure_ascii=False, indent=4)
docs/transformers/utils/check_build.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import argparse
16
+ import importlib
17
+ from pathlib import Path
18
+
19
+
20
+ # Test all the extensions added in the setup
21
+ FILES_TO_FIND = [
22
+ "kernels/rwkv/wkv_cuda.cu",
23
+ "kernels/rwkv/wkv_op.cpp",
24
+ "kernels/falcon_mamba/selective_scan_with_ln_interface.py",
25
+ "kernels/falcon_mamba/__init__.py",
26
+ "kernels/__init__.py",
27
+ "models/graphormer/algos_graphormer.pyx",
28
+ ]
29
+
30
+
31
+ def test_custom_files_are_present(transformers_path):
32
+ # Test all the extensions added in the setup
33
+ for file in FILES_TO_FIND:
34
+ if not (transformers_path / file).exists():
35
+ return False
36
+ return True
37
+
38
+
39
+ if __name__ == "__main__":
40
+ parser = argparse.ArgumentParser()
41
+ parser.add_argument("--check_lib", action="store_true", help="Whether to check the build or the actual package.")
42
+ args = parser.parse_args()
43
+ if args.check_lib:
44
+ transformers_module = importlib.import_module("transformers")
45
+ transformers_path = Path(transformers_module.__file__).parent
46
+ else:
47
+ transformers_path = Path.cwd() / "build/lib/transformers"
48
+ if not test_custom_files_are_present(transformers_path):
49
+ raise ValueError("The built release does not contain the custom files. Fix this before going further!")
docs/transformers/utils/check_config_attributes.py ADDED
@@ -0,0 +1,470 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import inspect
17
+ import os
18
+ import re
19
+
20
+ from transformers.configuration_utils import PretrainedConfig
21
+ from transformers.utils import direct_transformers_import
22
+
23
+
24
+ # All paths are set with the intent you should run this script from the root of the repo with the command
25
+ # python utils/check_config_docstrings.py
26
+ PATH_TO_TRANSFORMERS = "src/transformers"
27
+
28
+
29
+ # This is to make sure the transformers module imported is the one in the repo.
30
+ transformers = direct_transformers_import(PATH_TO_TRANSFORMERS)
31
+
32
+ CONFIG_MAPPING = transformers.models.auto.configuration_auto.CONFIG_MAPPING
33
+
34
+ SPECIAL_CASES_TO_ALLOW = {
35
+ # 'max_position_embeddings' is not used in modeling file, but needed for eval frameworks like Huggingface's lighteval (https://github.com/huggingface/lighteval/blob/af24080ea4f16eaf1683e353042a2dfc9099f038/src/lighteval/models/base_model.py#L264).
36
+ # periods and offsets are not used in modeling file, but used in the configuration file to define `layers_block_type` and `layers_num_experts`.
37
+ "BambaConfig": [
38
+ "attn_layer_indices",
39
+ ],
40
+ "JambaConfig": [
41
+ "max_position_embeddings",
42
+ "attn_layer_offset",
43
+ "attn_layer_period",
44
+ "expert_layer_offset",
45
+ "expert_layer_period",
46
+ ],
47
+ "Qwen2Config": ["use_sliding_window"],
48
+ "Qwen2MoeConfig": ["use_sliding_window"],
49
+ "Qwen2VLConfig": ["use_sliding_window"],
50
+ # `cache_implementation` should be in the default generation config, but we don't yet support per-model
51
+ # generation configs (TODO joao)
52
+ "Gemma2Config": ["tie_word_embeddings", "cache_implementation"],
53
+ "Cohere2Config": ["cache_implementation"],
54
+ # Dropout with this value was declared but never used
55
+ "Phi3Config": ["embd_pdrop"],
56
+ # used to compute the property `self.chunk_length`
57
+ "EncodecConfig": ["overlap"],
58
+ # used to compute the property `self.layers_block_type`
59
+ "RecurrentGemmaConfig": ["block_types"],
60
+ # used as in the config to define `intermediate_size`
61
+ "MambaConfig": ["expand"],
62
+ # used as in the config to define `intermediate_size`
63
+ "FalconMambaConfig": ["expand"],
64
+ # used as `self.bert_model = BertModel(config, ...)`
65
+ "DPRConfig": True,
66
+ "FuyuConfig": True,
67
+ # not used in modeling files, but it's an important information
68
+ "FSMTConfig": ["langs"],
69
+ # used internally in the configuration class file
70
+ "GPTNeoConfig": ["attention_types"],
71
+ # used internally in the configuration class file
72
+ "EsmConfig": ["is_folding_model"],
73
+ # used during training (despite we don't have training script for these models yet)
74
+ "Mask2FormerConfig": ["ignore_value"],
75
+ # `ignore_value` used during training (despite we don't have training script for these models yet)
76
+ # `norm` used in conversion script (despite not using in the modeling file)
77
+ "OneFormerConfig": ["ignore_value", "norm"],
78
+ # used internally in the configuration class file
79
+ "T5Config": ["feed_forward_proj"],
80
+ # used internally in the configuration class file
81
+ # `tokenizer_class` get default value `T5Tokenizer` intentionally
82
+ "MT5Config": ["feed_forward_proj", "tokenizer_class"],
83
+ "UMT5Config": ["feed_forward_proj", "tokenizer_class"],
84
+ # used internally in the configuration class file
85
+ "LongT5Config": ["feed_forward_proj"],
86
+ # used internally in the configuration class file
87
+ "Pop2PianoConfig": ["feed_forward_proj"],
88
+ # used internally in the configuration class file
89
+ "SwitchTransformersConfig": ["feed_forward_proj"],
90
+ # having default values other than `1e-5` - we can't fix them without breaking
91
+ "BioGptConfig": ["layer_norm_eps"],
92
+ # having default values other than `1e-5` - we can't fix them without breaking
93
+ "GLPNConfig": ["layer_norm_eps"],
94
+ # having default values other than `1e-5` - we can't fix them without breaking
95
+ "SegformerConfig": ["layer_norm_eps"],
96
+ # having default values other than `1e-5` - we can't fix them without breaking
97
+ "CvtConfig": ["layer_norm_eps"],
98
+ # having default values other than `1e-5` - we can't fix them without breaking
99
+ "PerceiverConfig": ["layer_norm_eps"],
100
+ # used internally to calculate the feature size
101
+ "InformerConfig": ["num_static_real_features", "num_time_features"],
102
+ # used internally to calculate the feature size
103
+ "TimeSeriesTransformerConfig": ["num_static_real_features", "num_time_features"],
104
+ # used internally to calculate the feature size
105
+ "AutoformerConfig": ["num_static_real_features", "num_time_features"],
106
+ # used internally to calculate `mlp_dim`
107
+ "SamVisionConfig": ["mlp_ratio"],
108
+ # For (head) training, but so far not implemented
109
+ "ClapAudioConfig": ["num_classes"],
110
+ # Not used, but providing useful information to users
111
+ "SpeechT5HifiGanConfig": ["sampling_rate"],
112
+ # used internally in the configuration class file
113
+ "UdopConfig": ["feed_forward_proj"],
114
+ # Actually used in the config or generation config, in that case necessary for the sub-components generation
115
+ "SeamlessM4TConfig": [
116
+ "max_new_tokens",
117
+ "t2u_max_new_tokens",
118
+ "t2u_decoder_attention_heads",
119
+ "t2u_decoder_ffn_dim",
120
+ "t2u_decoder_layers",
121
+ "t2u_encoder_attention_heads",
122
+ "t2u_encoder_ffn_dim",
123
+ "t2u_encoder_layers",
124
+ "t2u_max_position_embeddings",
125
+ ],
126
+ # Actually used in the config or generation config, in that case necessary for the sub-components generation
127
+ "SeamlessM4Tv2Config": [
128
+ "max_new_tokens",
129
+ "t2u_decoder_attention_heads",
130
+ "t2u_decoder_ffn_dim",
131
+ "t2u_decoder_layers",
132
+ "t2u_encoder_attention_heads",
133
+ "t2u_encoder_ffn_dim",
134
+ "t2u_encoder_layers",
135
+ "t2u_max_position_embeddings",
136
+ "t2u_variance_pred_dropout",
137
+ "t2u_variance_predictor_embed_dim",
138
+ "t2u_variance_predictor_hidden_dim",
139
+ "t2u_variance_predictor_kernel_size",
140
+ ],
141
+ "ZambaConfig": [
142
+ "tie_word_embeddings",
143
+ "attn_layer_offset",
144
+ "attn_layer_period",
145
+ ],
146
+ "MllamaTextConfig": [
147
+ "initializer_range",
148
+ ],
149
+ "MllamaVisionConfig": [
150
+ "initializer_range",
151
+ "supported_aspect_ratios",
152
+ ],
153
+ "ConditionalDetrConfig": [
154
+ "bbox_cost",
155
+ "bbox_loss_coefficient",
156
+ "class_cost",
157
+ "cls_loss_coefficient",
158
+ "dice_loss_coefficient",
159
+ "focal_alpha",
160
+ "giou_cost",
161
+ "giou_loss_coefficient",
162
+ "mask_loss_coefficient",
163
+ ],
164
+ "DabDetrConfig": [
165
+ "dilation",
166
+ "bbox_cost",
167
+ "bbox_loss_coefficient",
168
+ "class_cost",
169
+ "cls_loss_coefficient",
170
+ "focal_alpha",
171
+ "giou_cost",
172
+ "giou_loss_coefficient",
173
+ ],
174
+ "DetrConfig": [
175
+ "bbox_cost",
176
+ "bbox_loss_coefficient",
177
+ "class_cost",
178
+ "dice_loss_coefficient",
179
+ "eos_coefficient",
180
+ "giou_cost",
181
+ "giou_loss_coefficient",
182
+ "mask_loss_coefficient",
183
+ ],
184
+ "GroundingDinoConfig": [
185
+ "bbox_cost",
186
+ "bbox_loss_coefficient",
187
+ "class_cost",
188
+ "focal_alpha",
189
+ "giou_cost",
190
+ "giou_loss_coefficient",
191
+ ],
192
+ "RTDetrConfig": [
193
+ "eos_coefficient",
194
+ "focal_loss_alpha",
195
+ "focal_loss_gamma",
196
+ "matcher_alpha",
197
+ "matcher_bbox_cost",
198
+ "matcher_class_cost",
199
+ "matcher_gamma",
200
+ "matcher_giou_cost",
201
+ "use_focal_loss",
202
+ "weight_loss_bbox",
203
+ "weight_loss_giou",
204
+ "weight_loss_vfl",
205
+ ],
206
+ "RTDetrV2Config": [
207
+ "eos_coefficient",
208
+ "focal_loss_alpha",
209
+ "focal_loss_gamma",
210
+ "matcher_alpha",
211
+ "matcher_bbox_cost",
212
+ "matcher_class_cost",
213
+ "matcher_gamma",
214
+ "matcher_giou_cost",
215
+ "use_focal_loss",
216
+ "weight_loss_bbox",
217
+ "weight_loss_giou",
218
+ "weight_loss_vfl",
219
+ ],
220
+ "YolosConfig": [
221
+ "bbox_cost",
222
+ "bbox_loss_coefficient",
223
+ "class_cost",
224
+ "eos_coefficient",
225
+ "giou_cost",
226
+ "giou_loss_coefficient",
227
+ ],
228
+ "GPTNeoXConfig": ["rotary_emb_base"],
229
+ "Gemma3Config": ["boi_token_index", "eoi_token_index"],
230
+ "Gemma3TextConfig": ["cache_implementation", "tie_word_embeddings"],
231
+ "ShieldGemma2Config": [
232
+ "boi_token_index",
233
+ "eoi_token_index",
234
+ "initializer_range",
235
+ "mm_tokens_per_image",
236
+ "text_config",
237
+ "vision_config",
238
+ ],
239
+ "Llama4Config": ["boi_token_index", "eoi_token_index"],
240
+ "Llama4TextConfig": [
241
+ "interleave_moe_layer_step",
242
+ "no_rope_layer_interval",
243
+ "no_rope_layers",
244
+ "output_router_logits",
245
+ "router_aux_loss_coef",
246
+ "router_jitter_noise",
247
+ "cache_implementation",
248
+ ],
249
+ "Llama4VisionConfig": ["multi_modal_projector_bias", "norm_eps"],
250
+ }
251
+
252
+
253
+ # TODO (ydshieh): Check the failing cases, try to fix them or move some cases to the above block once we are sure
254
+ SPECIAL_CASES_TO_ALLOW.update(
255
+ {
256
+ "CLIPSegConfig": True,
257
+ "DeformableDetrConfig": True,
258
+ "DinatConfig": True,
259
+ "DonutSwinConfig": True,
260
+ "FastSpeech2ConformerConfig": True,
261
+ "FSMTConfig": True,
262
+ "LayoutLMv2Config": True,
263
+ "MaskFormerSwinConfig": True,
264
+ "MT5Config": True,
265
+ # For backward compatibility with trust remote code models
266
+ "MptConfig": True,
267
+ "MptAttentionConfig": True,
268
+ "OneFormerConfig": True,
269
+ "PerceiverConfig": True,
270
+ "RagConfig": True,
271
+ "SpeechT5Config": True,
272
+ "SwinConfig": True,
273
+ "Swin2SRConfig": True,
274
+ "Swinv2Config": True,
275
+ "SwitchTransformersConfig": True,
276
+ "TableTransformerConfig": True,
277
+ "TapasConfig": True,
278
+ "UniSpeechConfig": True,
279
+ "UniSpeechSatConfig": True,
280
+ "WavLMConfig": True,
281
+ "WhisperConfig": True,
282
+ # TODO: @Arthur (for `alignment_head` and `alignment_layer`)
283
+ "JukeboxPriorConfig": True,
284
+ # TODO: @Younes (for `is_decoder`)
285
+ "Pix2StructTextConfig": True,
286
+ "IdeficsConfig": True,
287
+ "IdeficsVisionConfig": True,
288
+ "IdeficsPerceiverConfig": True,
289
+ }
290
+ )
291
+
292
+
293
+ def check_attribute_being_used(config_class, attributes, default_value, source_strings):
294
+ """Check if any name in `attributes` is used in one of the strings in `source_strings`
295
+
296
+ Args:
297
+ config_class (`type`):
298
+ The configuration class for which the arguments in its `__init__` will be checked.
299
+ attributes (`List[str]`):
300
+ The name of an argument (or attribute) and its variant names if any.
301
+ default_value (`Any`):
302
+ A default value for the attribute in `attributes` assigned in the `__init__` of `config_class`.
303
+ source_strings (`List[str]`):
304
+ The python source code strings in the same modeling directory where `config_class` is defined. The file
305
+ containing the definition of `config_class` should be excluded.
306
+ """
307
+ attribute_used = False
308
+ for attribute in attributes:
309
+ for modeling_source in source_strings:
310
+ # check if we can find `config.xxx`, `getattr(config, "xxx", ...)` or `getattr(self.config, "xxx", ...)`
311
+ if (
312
+ f"config.{attribute}" in modeling_source
313
+ or f'getattr(config, "{attribute}"' in modeling_source
314
+ or f'getattr(self.config, "{attribute}"' in modeling_source
315
+ or (
316
+ "TextConfig" in config_class.__name__
317
+ and f"config.get_text_config().{attribute}" in modeling_source
318
+ )
319
+ ):
320
+ attribute_used = True
321
+ # Deal with multi-line cases
322
+ elif (
323
+ re.search(
324
+ rf'getattr[ \t\v\n\r\f]*\([ \t\v\n\r\f]*(self\.)?config,[ \t\v\n\r\f]*"{attribute}"',
325
+ modeling_source,
326
+ )
327
+ is not None
328
+ ):
329
+ attribute_used = True
330
+ if attribute_used:
331
+ break
332
+ if attribute_used:
333
+ break
334
+
335
+ # common and important attributes, even if they do not always appear in the modeling files
336
+ attributes_to_allow = [
337
+ "initializer_range",
338
+ "bos_index",
339
+ "eos_index",
340
+ "pad_index",
341
+ "unk_index",
342
+ "mask_index",
343
+ "image_token_id", # for VLMs
344
+ "video_token_id",
345
+ "image_seq_length",
346
+ "video_seq_length",
347
+ "image_size",
348
+ "text_config", # may appear as `get_text_config()`
349
+ "use_cache",
350
+ "out_features",
351
+ "out_indices",
352
+ "sampling_rate",
353
+ # backbone related arguments passed to load_backbone
354
+ "use_pretrained_backbone",
355
+ "backbone",
356
+ "backbone_config",
357
+ "use_timm_backbone",
358
+ "backbone_kwargs",
359
+ # rope attributes may not appear directly in the modeling but are used
360
+ "rope_theta",
361
+ "partial_rotary_factor",
362
+ "pretraining_tp",
363
+ "boi_token_index",
364
+ "eoi_token_index",
365
+ ]
366
+ attributes_used_in_generation = ["encoder_no_repeat_ngram_size"]
367
+
368
+ # Special cases to be allowed
369
+ case_allowed = True
370
+ if not attribute_used:
371
+ case_allowed = False
372
+ for attribute in attributes:
373
+ # Allow if the default value in the configuration class is different from the one in `PretrainedConfig`
374
+ if attribute in ["is_encoder_decoder"] and default_value is True:
375
+ case_allowed = True
376
+ elif attribute in ["tie_word_embeddings"] and default_value is False:
377
+ case_allowed = True
378
+
379
+ # Allow cases without checking the default value in the configuration class
380
+ elif attribute in attributes_to_allow + attributes_used_in_generation:
381
+ case_allowed = True
382
+ elif attribute.endswith("_token_id"):
383
+ case_allowed = True
384
+
385
+ # configuration class specific cases
386
+ if not case_allowed:
387
+ allowed_cases = SPECIAL_CASES_TO_ALLOW.get(config_class.__name__, [])
388
+ case_allowed = allowed_cases is True or attribute in allowed_cases
389
+
390
+ return attribute_used or case_allowed
391
+
392
+
393
+ def check_config_attributes_being_used(config_class):
394
+ """Check the arguments in `__init__` of `config_class` are used in the modeling files in the same directory
395
+
396
+ Args:
397
+ config_class (`type`):
398
+ The configuration class for which the arguments in its `__init__` will be checked.
399
+ """
400
+ # Get the parameters in `__init__` of the configuration class, and the default values if any
401
+ signature = dict(inspect.signature(config_class.__init__).parameters)
402
+ parameter_names = [x for x in list(signature.keys()) if x not in ["self", "kwargs"]]
403
+ parameter_defaults = [signature[param].default for param in parameter_names]
404
+
405
+ # If `attribute_map` exists, an attribute can have different names to be used in the modeling files, and as long
406
+ # as one variant is used, the test should pass
407
+ reversed_attribute_map = {}
408
+ if len(config_class.attribute_map) > 0:
409
+ reversed_attribute_map = {v: k for k, v in config_class.attribute_map.items()}
410
+
411
+ # Get the path to modeling source files
412
+ config_source_file = inspect.getsourcefile(config_class)
413
+ model_dir = os.path.dirname(config_source_file)
414
+ # Let's check against all frameworks: as long as one framework uses an attribute, we are good.
415
+ modeling_paths = [os.path.join(model_dir, fn) for fn in os.listdir(model_dir) if fn.startswith("modeling_")]
416
+
417
+ # Get the source code strings
418
+ modeling_sources = []
419
+ for path in modeling_paths:
420
+ if os.path.isfile(path):
421
+ with open(path, encoding="utf8") as fp:
422
+ modeling_sources.append(fp.read())
423
+
424
+ unused_attributes = []
425
+ for config_param, default_value in zip(parameter_names, parameter_defaults):
426
+ # `attributes` here is all the variant names for `config_param`
427
+ attributes = [config_param]
428
+ # some configuration classes have non-empty `attribute_map`, and both names could be used in the
429
+ # corresponding modeling files. As long as one of them appears, it is fine.
430
+ if config_param in reversed_attribute_map:
431
+ attributes.append(reversed_attribute_map[config_param])
432
+
433
+ if not check_attribute_being_used(config_class, attributes, default_value, modeling_sources):
434
+ unused_attributes.append(attributes[0])
435
+
436
+ return sorted(unused_attributes)
437
+
438
+
439
+ def check_config_attributes():
440
+ """Check the arguments in `__init__` of all configuration classes are used in python files"""
441
+ configs_with_unused_attributes = {}
442
+ for _config_class in list(CONFIG_MAPPING.values()):
443
+ # Skip deprecated models
444
+ if "models.deprecated" in _config_class.__module__:
445
+ continue
446
+ # Some config classes are not in `CONFIG_MAPPING` (e.g. `CLIPVisionConfig`, `Blip2VisionConfig`, etc.)
447
+ config_classes_in_module = [
448
+ cls
449
+ for name, cls in inspect.getmembers(
450
+ inspect.getmodule(_config_class),
451
+ lambda x: inspect.isclass(x)
452
+ and issubclass(x, PretrainedConfig)
453
+ and inspect.getmodule(x) == inspect.getmodule(_config_class),
454
+ )
455
+ ]
456
+ for config_class in config_classes_in_module:
457
+ unused_attributes = check_config_attributes_being_used(config_class)
458
+ if len(unused_attributes) > 0:
459
+ configs_with_unused_attributes[config_class.__name__] = unused_attributes
460
+
461
+ if len(configs_with_unused_attributes) > 0:
462
+ error = "The following configuration classes contain unused attributes in the corresponding modeling files:\n"
463
+ for name, attributes in configs_with_unused_attributes.items():
464
+ error += f"{name}: {attributes}\n"
465
+
466
+ raise ValueError(error)
467
+
468
+
469
+ if __name__ == "__main__":
470
+ check_config_attributes()
docs/transformers/utils/check_config_docstrings.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import inspect
17
+ import re
18
+
19
+ from transformers.utils import direct_transformers_import
20
+
21
+
22
+ # All paths are set with the intent you should run this script from the root of the repo with the command
23
+ # python utils/check_config_docstrings.py
24
+ PATH_TO_TRANSFORMERS = "src/transformers"
25
+
26
+
27
+ # This is to make sure the transformers module imported is the one in the repo.
28
+ transformers = direct_transformers_import(PATH_TO_TRANSFORMERS)
29
+
30
+ CONFIG_MAPPING = transformers.models.auto.configuration_auto.CONFIG_MAPPING
31
+
32
+ # Regex pattern used to find the checkpoint mentioned in the docstring of `config_class`.
33
+ # For example, `[google-bert/bert-base-uncased](https://huggingface.co/google-bert/bert-base-uncased)`
34
+ _re_checkpoint = re.compile(r"\[(.+?)\]\((https://huggingface\.co/.+?)\)")
35
+
36
+
37
+ CONFIG_CLASSES_TO_IGNORE_FOR_DOCSTRING_CHECKPOINT_CHECK = {
38
+ "DecisionTransformerConfig",
39
+ "EncoderDecoderConfig",
40
+ "MusicgenConfig",
41
+ "RagConfig",
42
+ "SpeechEncoderDecoderConfig",
43
+ "TimmBackboneConfig",
44
+ "TimmWrapperConfig",
45
+ "VisionEncoderDecoderConfig",
46
+ "VisionTextDualEncoderConfig",
47
+ "LlamaConfig",
48
+ "GraniteConfig",
49
+ "GraniteMoeConfig",
50
+ "Qwen3MoeConfig",
51
+ "GraniteSpeechConfig",
52
+ }
53
+
54
+
55
+ def get_checkpoint_from_config_class(config_class):
56
+ checkpoint = None
57
+
58
+ # source code of `config_class`
59
+ config_source = inspect.getsource(config_class)
60
+ checkpoints = _re_checkpoint.findall(config_source)
61
+
62
+ # Each `checkpoint` is a tuple of a checkpoint name and a checkpoint link.
63
+ # For example, `('google-bert/bert-base-uncased', 'https://huggingface.co/google-bert/bert-base-uncased')`
64
+ for ckpt_name, ckpt_link in checkpoints:
65
+ # allow the link to end with `/`
66
+ if ckpt_link.endswith("/"):
67
+ ckpt_link = ckpt_link[:-1]
68
+
69
+ # verify the checkpoint name corresponds to the checkpoint link
70
+ ckpt_link_from_name = f"https://huggingface.co/{ckpt_name}"
71
+ if ckpt_link == ckpt_link_from_name:
72
+ checkpoint = ckpt_name
73
+ break
74
+
75
+ return checkpoint
76
+
77
+
78
+ def check_config_docstrings_have_checkpoints():
79
+ configs_without_checkpoint = []
80
+
81
+ for config_class in list(CONFIG_MAPPING.values()):
82
+ # Skip deprecated models
83
+ if "models.deprecated" in config_class.__module__:
84
+ continue
85
+ checkpoint = get_checkpoint_from_config_class(config_class)
86
+
87
+ name = config_class.__name__
88
+ if checkpoint is None and name not in CONFIG_CLASSES_TO_IGNORE_FOR_DOCSTRING_CHECKPOINT_CHECK:
89
+ configs_without_checkpoint.append(name)
90
+
91
+ if len(configs_without_checkpoint) > 0:
92
+ message = "\n".join(sorted(configs_without_checkpoint))
93
+ raise ValueError(
94
+ f"The following configurations don't contain any valid checkpoint:\n{message}\n\n"
95
+ "The requirement is to include a link pointing to one of the models of this architecture in the "
96
+ "docstring of the config classes listed above. The link should have be a markdown format like "
97
+ "[myorg/mymodel](https://huggingface.co/myorg/mymodel)."
98
+ )
99
+
100
+
101
+ if __name__ == "__main__":
102
+ check_config_docstrings_have_checkpoints()
docs/transformers/utils/check_copies.py ADDED
@@ -0,0 +1,1078 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2020 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Utility that checks whether the copies defined in the library match the original or not. This includes:
17
+ - All code commented with `# Copied from` comments,
18
+ - The list of models in the main README.md matches the ones in the localized READMEs,
19
+ - Files that are registered as full copies of one another in the `FULL_COPIES` constant of this script.
20
+
21
+ This also checks the list of models in the README is complete (has all models) and add a line to complete if there is
22
+ a model missing.
23
+
24
+ Use from the root of the repo with:
25
+
26
+ ```bash
27
+ python utils/check_copies.py
28
+ ```
29
+
30
+ for a check that will error in case of inconsistencies (used by `make repo-consistency`) or
31
+
32
+ ```bash
33
+ python utils/check_copies.py --fix_and_overwrite
34
+ ```
35
+
36
+ for a check that will fix all inconsistencies automatically (used by `make fix-copies`).
37
+ """
38
+
39
+ import argparse
40
+ import glob
41
+ import os
42
+ import re
43
+ import subprocess
44
+ from collections import OrderedDict
45
+ from typing import List, Optional, Tuple, Union
46
+
47
+ from transformers.utils import direct_transformers_import
48
+
49
+
50
+ # All paths are set with the intent you should run this script from the root of the repo with the command
51
+ # python utils/check_copies.py
52
+ TRANSFORMERS_PATH = "src/transformers"
53
+ MODEL_TEST_PATH = "tests/models"
54
+ PATH_TO_DOCS = "docs/source/en"
55
+ REPO_PATH = "."
56
+
57
+ # Mapping for files that are full copies of others (keys are copies, values the file to keep them up to data with)
58
+ FULL_COPIES = {
59
+ "examples/tensorflow/question-answering/utils_qa.py": "examples/pytorch/question-answering/utils_qa.py",
60
+ "examples/flax/question-answering/utils_qa.py": "examples/pytorch/question-answering/utils_qa.py",
61
+ }
62
+
63
+
64
+ LOCALIZED_READMES = {
65
+ # If the introduction or the conclusion of the list change, the prompts may need to be updated.
66
+ "README.md": {
67
+ "start_prompt": "🤗 Transformers currently provides the following architectures",
68
+ "end_prompt": "1. Want to contribute a new model?",
69
+ "format_model_list": (
70
+ "**[{title}]({model_link})** (from {paper_affiliations}) released with the paper {paper_title_link} by"
71
+ " {paper_authors}.{supplements}"
72
+ ),
73
+ },
74
+ "README_zh-hans.md": {
75
+ "start_prompt": "🤗 Transformers 目前支持如下的架构",
76
+ "end_prompt": "1. 想要贡献新的模型?",
77
+ "format_model_list": (
78
+ "**[{title}]({model_link})** (来自 {paper_affiliations}) 伴随论文 {paper_title_link} 由 {paper_authors}"
79
+ " 发布。{supplements}"
80
+ ),
81
+ },
82
+ "README_zh-hant.md": {
83
+ "start_prompt": "🤗 Transformers 目前支援以下的架構",
84
+ "end_prompt": "1. 想要貢獻新的模型?",
85
+ "format_model_list": (
86
+ "**[{title}]({model_link})** (from {paper_affiliations}) released with the paper {paper_title_link} by"
87
+ " {paper_authors}.{supplements}"
88
+ ),
89
+ },
90
+ "README_ko.md": {
91
+ "start_prompt": "🤗 Transformers는 다음 모델들을 제공합니다",
92
+ "end_prompt": "1. 새로운 모델을 올리고 싶나요?",
93
+ "format_model_list": (
94
+ "**[{title}]({model_link})** ({paper_affiliations} 에서 제공)은 {paper_authors}.{supplements}의"
95
+ " {paper_title_link}논문과 함께 발표했습니다."
96
+ ),
97
+ },
98
+ "README_es.md": {
99
+ "start_prompt": "🤗 Transformers actualmente proporciona las siguientes arquitecturas",
100
+ "end_prompt": "1. ¿Quieres aportar un nuevo modelo?",
101
+ "format_model_list": (
102
+ "**[{title}]({model_link})** (from {paper_affiliations}) released with the paper {paper_title_link} by"
103
+ " {paper_authors}.{supplements}"
104
+ ),
105
+ },
106
+ "README_ja.md": {
107
+ "start_prompt": "🤗Transformersは現在、以下のアーキテクチャを提供しています",
108
+ "end_prompt": "1. 新しいモデルを投稿したいですか?",
109
+ "format_model_list": (
110
+ "**[{title}]({model_link})** ({paper_affiliations} から) {paper_authors}.{supplements} から公開された研究論文"
111
+ " {paper_title_link}"
112
+ ),
113
+ },
114
+ "README_hd.md": {
115
+ "start_prompt": "🤗 ट्रांसफॉर्मर वर्तमान में निम्नलिखित आर्किटेक्चर का समर्थन करते हैं",
116
+ "end_prompt": "1. एक नए मॉडल में योगदान देना चाहते हैं?",
117
+ "format_model_list": (
118
+ "**[{title}]({model_link})** ({paper_affiliations} से) {paper_authors}.{supplements} द्वारा"
119
+ "अनुसंधान पत्र {paper_title_link} के साथ जारी किया गया"
120
+ ),
121
+ },
122
+ "README_ru.md": {
123
+ "start_prompt": "🤗 В настоящее время Transformers предоставляет следующие архитектуры",
124
+ "end_prompt": "1. Хотите внести новую модель?",
125
+ "format_model_list": (
126
+ "**[{title}]({model_link})** (from {paper_affiliations}) released with the paper {paper_title_link} by"
127
+ " {paper_authors}.{supplements}"
128
+ ),
129
+ },
130
+ "README_pt-br.md": {
131
+ "start_prompt": "🤗 Transformers atualmente fornece as seguintes arquiteturas",
132
+ "end_prompt": "1. Quer contribuir com um novo modelo?",
133
+ "format_model_list": (
134
+ "**[{title}]({model_link})** (from {paper_affiliations}) released with the paper {paper_title_link} by"
135
+ " {paper_authors}.{supplements}"
136
+ ),
137
+ },
138
+ "README_te.md": {
139
+ "start_prompt": "🤗 ట్రాన్స్‌ఫార్మర్లు ప్రస్తుతం కింది ఆర్కిటెక్చర్‌లను అందజేస్తున్నాయి",
140
+ "end_prompt": "1. కొత్త మోడల్‌ను అందించాలనుకుంటున్నారా?",
141
+ "format_model_list": (
142
+ "**[{title}]({model_link})** (from {paper_affiliations}) released with the paper {paper_title_link} by"
143
+ " {paper_authors}.{supplements}"
144
+ ),
145
+ },
146
+ "README_fr.md": {
147
+ "start_prompt": "🤗 Transformers fournit actuellement les architectures suivantes",
148
+ "end_prompt": "1. Vous souhaitez contribuer avec un nouveau modèle ?",
149
+ "format_model_list": (
150
+ "**[{title}]({model_link})** (de {paper_affiliations}) publié dans l'article {paper_title_link} par"
151
+ "{paper_authors}.{supplements}"
152
+ ),
153
+ },
154
+ "README_de.md": {
155
+ "start_prompt": "🤗 Transformers bietet derzeit die folgenden Architekturen an",
156
+ "end_prompt": "1. Möchten Sie ein neues Modell beitragen?",
157
+ "format_model_list": (
158
+ "**[{title}]({model_link})** (from {paper_affiliations}) released with the paper {paper_title_link} by"
159
+ " {paper_authors}.{supplements}"
160
+ ),
161
+ },
162
+ "README_vi.md": {
163
+ "start_prompt": "🤗 Transformers hiện đang cung cấp các kiến trúc sau đây",
164
+ "end_prompt": "1. Muốn đóng góp một mô hình mới?",
165
+ "format_model_list": (
166
+ "**[{title}]({model_link})** (từ {paper_affiliations}) được phát hành với bài báo {paper_title_link} by"
167
+ " {paper_authors}.{supplements}"
168
+ ),
169
+ },
170
+ }
171
+
172
+ # This is to make sure the transformers module imported is the one in the repo.
173
+ transformers_module = direct_transformers_import(TRANSFORMERS_PATH)
174
+
175
+
176
+ def _is_definition_header_ending_line(line: str) -> bool:
177
+ # Helper function. Returns `True` if `line` is the end parenthesis of a class/function definition
178
+ return re.search(r"^\s*\)(\s*->.*:|:)\s*$", line) is not None
179
+
180
+
181
+ def _should_continue(line: str, indent: str) -> bool:
182
+ # Helper function. Returns `True` if `line` is empty, starts with the `indent` or is the end parenthesis of a
183
+ # class/function definition
184
+ return line.startswith(indent) or len(line.strip()) == 0 or _is_definition_header_ending_line(line)
185
+
186
+
187
+ def _sanity_check_splits(splits_1, splits_2, is_class, filename):
188
+ """Check the two (inner) block structures of the corresponding code block given by `split_code_into_blocks` match.
189
+
190
+ For the case of `class`, they must be of one of the following 3 cases:
191
+
192
+ - a single block without name:
193
+
194
+ class foo:
195
+ a = 1
196
+
197
+ - a consecutive sequence of (1 or more) blocks with name
198
+
199
+ class foo:
200
+
201
+ def f(x):
202
+ return x
203
+
204
+ - a block without name, followed by a consecutive sequence of (1 or more) blocks with name
205
+
206
+ class foo:
207
+ a = 1
208
+
209
+ def f(x):
210
+ return x
211
+
212
+ def g(x):
213
+ return None
214
+
215
+ The 2 code snippets that give `splits_1` and `splits_2` have to be in the same case to pass this check, but the
216
+ number of blocks with name in the consecutive sequence is not taken into account.
217
+
218
+ For the case of `function or method`, we don't require it to be in one of the above 3 cases. However, the structure
219
+ of`splits_1` and `splits_2` have to match exactly. In particular, the number of blocks with name in a consecutive
220
+ sequence is taken into account.
221
+ """
222
+ block_names_1 = []
223
+ block_names_2 = []
224
+
225
+ for block in splits_1[1:]:
226
+ if block[0].startswith("_block_without_name_"):
227
+ block_names_1.append("block_without_name")
228
+ elif not block[0].startswith("_empty_block_") and (
229
+ not is_class or len(block_names_1) == 0 or block_names_1[-1].startswith("block_without_name")
230
+ ):
231
+ block_names_1.append("block_with_name")
232
+
233
+ for block in splits_2[1:]:
234
+ if block[0].startswith("_block_without_name_"):
235
+ block_names_2.append("block_without_name")
236
+ elif not block[0].startswith("_empty_block_") and (
237
+ not is_class or len(block_names_2) == 0 or block_names_2[-1].startswith("block_without_name")
238
+ ):
239
+ block_names_2.append("block_with_name")
240
+
241
+ if is_class:
242
+ if block_names_1 not in [
243
+ ["block_without_name"],
244
+ ["block_with_name"],
245
+ ["block_without_name", "block_with_name"],
246
+ ]:
247
+ raise ValueError(
248
+ f"""Class defined in {filename} doesn't have the expected structure.
249
+ See the docstring of `_sanity_check_splits` in the file `utils/check_copies.py`""",
250
+ )
251
+
252
+ if block_names_1 != block_names_2:
253
+ raise ValueError(f"In {filename}, two code blocks expected to be copies have different structures.")
254
+
255
+
256
+ def find_block_end(lines: List[str], start_index: int, indent: int) -> int:
257
+ """
258
+ Find the end of the class/func block starting at `start_index` in a source code (defined by `lines`).
259
+
260
+ Args:
261
+ lines (`List[str]`):
262
+ The source code, represented by a list of lines.
263
+ start_index (`int`):
264
+ The starting index of the target class/func block.
265
+ indent (`int`):
266
+ The indent of the class/func body.
267
+
268
+ Returns:
269
+ `int`: The index of the block's ending line plus by 1 (i.e. exclusive).
270
+ """
271
+ indent = " " * indent
272
+ # enter the block body
273
+ line_index = start_index + 1
274
+
275
+ while line_index < len(lines) and _should_continue(lines[line_index], indent):
276
+ line_index += 1
277
+ # Clean up empty lines at the end (if any).
278
+ while len(lines[line_index - 1]) <= 1:
279
+ line_index -= 1
280
+
281
+ return line_index
282
+
283
+
284
+ def split_code_into_blocks(
285
+ lines: List[str], start_index: int, end_index: int, indent: int, backtrace: bool = False
286
+ ) -> List[Tuple[str, int, int]]:
287
+ """
288
+ Split the class/func block starting at `start_index` in a source code (defined by `lines`) into *inner blocks*.
289
+
290
+ The block's header is included as the first element. The contiguous regions (without empty lines) that are not
291
+ inside any inner block are included as blocks. The contiguous regions of empty lines that are not inside any inner
292
+ block are also included as (dummy) blocks.
293
+
294
+ Args:
295
+ lines (`List[str]`):
296
+ The source code, represented by a list of lines.
297
+ start_index (`int`):
298
+ The starting index of the target class/func block.
299
+ end_index (`int`):
300
+ The ending index of the target class/func block.
301
+ indent (`int`):
302
+ The indent of the class/func body.
303
+ backtrace (`bool`, *optional*, defaults to `False`):
304
+ Whether or not to include the lines before the inner class/func block's header (e.g. comments, decorators,
305
+ etc.) until an empty line is encountered.
306
+
307
+ Returns:
308
+ `List[Tuple[str, int, int]]`: A list of elements with the form `(block_name, start_index, end_index)`.
309
+ """
310
+ splits = []
311
+ # `indent - 4` is the indent level of the target class/func header
312
+ try:
313
+ target_block_name = re.search(
314
+ rf"^{' ' * (indent - 4)}((class|def)\s+\S+)(\(|\:)", lines[start_index]
315
+ ).groups()[0]
316
+ except Exception:
317
+ start_context = min(start_index - 10, 0)
318
+ end_context = min(end_index + 10, len(lines))
319
+ raise ValueError(
320
+ f"Tried to split a class or function. It did not work. Error comes from line {start_index}: \n```\n"
321
+ + "".join(lines[start_context:end_context])
322
+ + "```\n"
323
+ )
324
+
325
+ # from now on, the `block` means inner blocks unless explicitly specified
326
+ indent_str = " " * indent
327
+ block_without_name_idx = 0
328
+ empty_block_idx = 0
329
+
330
+ # Find the lines for the definition header
331
+ index = start_index
332
+ if "(" in lines[start_index] and "):" not in lines[start_index] in lines[start_index]:
333
+ while index < end_index:
334
+ if _is_definition_header_ending_line(lines[index]):
335
+ break
336
+ index += 1
337
+
338
+ # the first line outside the definition header
339
+ index += 1
340
+ splits.append((target_block_name, start_index, index))
341
+
342
+ block_start_index, prev_block_end_index = index, index
343
+ while index < end_index:
344
+ # if found, it will be an inner block
345
+ block_found = re.search(rf"^{indent_str}((class|def)\s+\S+)(\(|\:)", lines[index])
346
+ if block_found:
347
+ name = block_found.groups()[0]
348
+
349
+ block_end_index = find_block_end(lines, index, indent + 4)
350
+
351
+ # backtrace to include the lines before the found block's definition header (e.g. comments, decorators,
352
+ # etc.) until an empty line is encountered.
353
+ block_start_index = index
354
+ if index > prev_block_end_index and backtrace:
355
+ idx = index - 1
356
+ for idx in range(index - 1, prev_block_end_index - 2, -1):
357
+ if not (len(lines[idx].strip()) > 0 and lines[idx].startswith(indent_str)):
358
+ break
359
+ idx += 1
360
+ if idx < index:
361
+ block_start_index = idx
362
+
363
+ # between the current found block and the previous found block
364
+ if block_start_index > prev_block_end_index:
365
+ # give it a dummy name
366
+ if len("".join(lines[prev_block_end_index:block_start_index]).strip()) == 0:
367
+ prev_block_name = f"_empty_block_{empty_block_idx}"
368
+ empty_block_idx += 1
369
+ else:
370
+ prev_block_name = f"_block_without_name_{block_without_name_idx}"
371
+ block_without_name_idx += 1
372
+ # Add it as a block
373
+ splits.append((prev_block_name, prev_block_end_index, block_start_index))
374
+
375
+ # Add the current found block
376
+ splits.append((name, block_start_index, block_end_index))
377
+ prev_block_end_index = block_end_index
378
+ index = block_end_index - 1
379
+
380
+ index += 1
381
+
382
+ if index > prev_block_end_index:
383
+ if len("".join(lines[prev_block_end_index:index]).strip()) == 0:
384
+ prev_block_name = f"_empty_block_{empty_block_idx}"
385
+ else:
386
+ prev_block_name = f"_block_without_name_{block_without_name_idx}"
387
+ splits.append((prev_block_name, prev_block_end_index, index))
388
+
389
+ return splits
390
+
391
+
392
+ def find_code_in_transformers(
393
+ object_name: str, base_path: str = None, return_indices: bool = False
394
+ ) -> Union[str, Tuple[List[str], int, int]]:
395
+ """
396
+ Find and return the source code of an object.
397
+
398
+ Args:
399
+ object_name (`str`):
400
+ The name of the object we want the source code of.
401
+ base_path (`str`, *optional*):
402
+ The path to the base folder where files are checked. If not set, it will be set to `TRANSFORMERS_PATH`.
403
+ return_indices(`bool`, *optional*, defaults to `False`):
404
+ If `False`, will only return the code (as a string), otherwise it will also return the whole lines of the
405
+ file where the object specified by `object_name` is defined, together the start/end indices of the block in
406
+ the file that defines the object.
407
+
408
+ Returns:
409
+ `Union[str, Tuple[List[str], int, int]]`: If `return_indices=False`, only the source code of the object will be
410
+ returned. Otherwise, it also returns the whole lines of the file where the object specified by `object_name` is
411
+ defined, together the start/end indices of the block in the file that defines the object.
412
+ """
413
+ parts = object_name.split(".")
414
+ i = 0
415
+
416
+ # We can't set this as the default value in the argument, otherwise `CopyCheckTester` will fail, as it uses a
417
+ # patched temp directory.
418
+ if base_path is None:
419
+ base_path = TRANSFORMERS_PATH
420
+
421
+ # Detail: the `Copied from` statement is originally designed to work with the last part of `TRANSFORMERS_PATH`,
422
+ # (which is `transformers`). The same should be applied for `MODEL_TEST_PATH`. However, its last part is `models`
423
+ # (to only check and search in it) which is a bit confusing. So we keep the copied statement staring with
424
+ # `tests.models.` and change it to `tests` here.
425
+ if base_path == MODEL_TEST_PATH:
426
+ base_path = "tests"
427
+
428
+ # First let's find the module where our object lives.
429
+ module = parts[i]
430
+ while i < len(parts) and not os.path.isfile(os.path.join(base_path, f"{module}.py")):
431
+ i += 1
432
+ if i < len(parts):
433
+ module = os.path.join(module, parts[i])
434
+ if i >= len(parts):
435
+ raise ValueError(
436
+ f"`object_name` should begin with the name of a module of transformers but got {object_name}."
437
+ )
438
+
439
+ with open(os.path.join(base_path, f"{module}.py"), "r", encoding="utf-8", newline="\n") as f:
440
+ lines = f.readlines()
441
+
442
+ # Now let's find the class / func in the code!
443
+ indent = ""
444
+ line_index = 0
445
+ for name in parts[i + 1 :]:
446
+ while (
447
+ line_index < len(lines) and re.search(rf"^{indent}(class|def)\s+{name}(\(|\:)", lines[line_index]) is None
448
+ ):
449
+ line_index += 1
450
+ # find the target specified in the current level in `parts` -> increase `indent` so we can search the next
451
+ indent += " "
452
+ # the index of the first line in the (currently found) block *body*
453
+ line_index += 1
454
+
455
+ if line_index >= len(lines):
456
+ raise ValueError(f" {object_name} does not match any function or class in {module}.")
457
+
458
+ # `indent` is already one level deeper than the (found) class/func block's definition header
459
+
460
+ # We found the beginning of the class / func, now let's find the end (when the indent diminishes).
461
+ # `start_index` is the index of the class/func block's definition header
462
+ start_index = line_index - 1
463
+ end_index = find_block_end(lines, start_index, len(indent))
464
+
465
+ code = "".join(lines[start_index:end_index])
466
+ return (code, (lines, start_index, end_index)) if return_indices else code
467
+
468
+
469
+ def replace_code(code: str, replace_pattern: str) -> str:
470
+ """Replace `code` by a pattern of the form `with X1->X2,Y1->Y2,Z1->Z2`.
471
+
472
+ Args:
473
+ code (`str`): The code to be modified.
474
+ replace_pattern (`str`): The pattern used to modify `code`.
475
+
476
+ Returns:
477
+ `str`: The modified code.
478
+ """
479
+ if len(replace_pattern) > 0:
480
+ patterns = replace_pattern.replace("with", "").split(",")
481
+ patterns = [_re_replace_pattern.search(p) for p in patterns]
482
+ for pattern in patterns:
483
+ if pattern is None:
484
+ continue
485
+ obj1, obj2, option = pattern.groups()
486
+ code = re.sub(obj1, obj2, code)
487
+ if option.strip() == "all-casing":
488
+ code = re.sub(obj1.lower(), obj2.lower(), code)
489
+ code = re.sub(obj1.upper(), obj2.upper(), code)
490
+
491
+ return code
492
+
493
+
494
+ def find_code_and_splits(object_name: str, base_path: str, buffer: dict = None):
495
+ """Find the code of an object (specified by `object_name`) and split it into blocks.
496
+
497
+ Args:
498
+ object_name (`str`):
499
+ The name of the object, e.g. `transformers.models.bert.modeling_bert.BertAttention` or
500
+ `tests.models.llama.test_modeling_llama.LlamaModelTest.test_config`.
501
+ base_path (`str`):
502
+ The path to the base directory within which the search will be performed. It could be either
503
+ `TRANSFORMERS_PATH` or `MODEL_TEST_PATH`.
504
+ buffer (`dict`, *optional*):
505
+ The buffer used to store the previous results in order to speed up the process.
506
+
507
+ Returns:
508
+ lines (`List[str]`):
509
+ The lines of the whole file where the object is defined.
510
+ code (`str`):
511
+ The object's code.
512
+ code_splits (`List[Tuple[str, int, int]]`):
513
+ `code` splitted into blocks. See `split_code_into_blocks`.
514
+ """
515
+ if buffer is None:
516
+ buffer = {}
517
+
518
+ if (object_name, base_path) in buffer:
519
+ lines, code, code_splits = buffer[(object_name, base_path)]
520
+ else:
521
+ code, (lines, target_start_index, target_end_index) = find_code_in_transformers(
522
+ object_name, base_path=base_path, return_indices=True
523
+ )
524
+ indent = get_indent(code)
525
+
526
+ # Split the code into blocks
527
+ # `indent` is the indent of the class/func definition header, but `code_splits` expects the indent level of the
528
+ # block body.
529
+ code_splits = split_code_into_blocks(
530
+ lines, target_start_index, target_end_index, len(indent) + 4, backtrace=True
531
+ )
532
+ buffer[(object_name, base_path)] = lines, code, code_splits
533
+
534
+ return lines, code, code_splits
535
+
536
+
537
+ _re_copy_warning = re.compile(r"^(\s*)#\s*Copied from\s+transformers\.(\S+\.\S+)\s*($|\S.*$)")
538
+ _re_copy_warning_for_test_file = re.compile(r"^(\s*)#\s*Copied from\s+tests\.(\S+\.\S+)\s*($|\S.*$)")
539
+ _re_replace_pattern = re.compile(r"^\s*(\S+)->(\S+)(\s+.*|$)")
540
+ _re_fill_pattern = re.compile(r"<FILL\s+[^>]*>")
541
+
542
+
543
+ def get_indent(code: str) -> str:
544
+ """
545
+ Find the indent in the first non empty line in a code sample.
546
+
547
+ Args:
548
+ code (`str`): The code to inspect.
549
+
550
+ Returns:
551
+ `str`: The indent looked at (as string).
552
+ """
553
+ lines = code.split("\n")
554
+ idx = 0
555
+ while idx < len(lines) and len(lines[idx]) == 0:
556
+ idx += 1
557
+ if idx < len(lines):
558
+ return re.search(r"^(\s*)\S", lines[idx]).groups()[0]
559
+ return ""
560
+
561
+
562
+ def run_ruff(code, check=False):
563
+ if check:
564
+ command = ["ruff", "check", "-", "--fix", "--exit-zero"]
565
+ else:
566
+ command = ["ruff", "format", "-", "--config", "pyproject.toml", "--silent"]
567
+ process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, stdin=subprocess.PIPE)
568
+ stdout, _ = process.communicate(input=code.encode())
569
+ return stdout.decode()
570
+
571
+
572
+ def stylify(code: str) -> str:
573
+ """
574
+ Applies the ruff part of our `make style` command to some code. This formats the code using `ruff format`.
575
+ As `ruff` does not provide a python api this cannot be done on the fly.
576
+
577
+ Args:
578
+ code (`str`): The code to format.
579
+
580
+ Returns:
581
+ `str`: The formatted code.
582
+ """
583
+ has_indent = len(get_indent(code)) > 0
584
+ if has_indent:
585
+ code = f"class Bla:\n{code}"
586
+ formatted_code = run_ruff(code)
587
+ return formatted_code[len("class Bla:\n") :] if has_indent else formatted_code
588
+
589
+
590
+ def check_codes_match(observed_code: str, theoretical_code: str) -> Optional[int]:
591
+ """
592
+ Checks if two version of a code match with the exception of the class/function name.
593
+
594
+ Args:
595
+ observed_code (`str`): The code found.
596
+ theoretical_code (`str`): The code to match.
597
+
598
+ Returns:
599
+ `Optional[int]`: The index of the first line where there is a difference (if any) and `None` if the codes
600
+ match.
601
+ """
602
+ observed_code_header = observed_code.split("\n")[0]
603
+ theoretical_code_header = theoretical_code.split("\n")[0]
604
+
605
+ # Catch the function/class name: it is expected that those do not match.
606
+ _re_class_match = re.compile(r"class\s+([^\(:]+)(?:\(|:)")
607
+ _re_func_match = re.compile(r"def\s+([^\(]+)\(")
608
+ for re_pattern in [_re_class_match, _re_func_match]:
609
+ if re_pattern.match(observed_code_header) is not None:
610
+ try:
611
+ observed_obj_name = re_pattern.search(observed_code_header).groups()[0]
612
+ except Exception:
613
+ raise ValueError(
614
+ "Tried to split a class or function. It did not work. Error comes from: \n```\n"
615
+ + observed_code_header
616
+ + "\n```\n"
617
+ )
618
+
619
+ try:
620
+ theoretical_name = re_pattern.search(theoretical_code_header).groups()[0]
621
+ except Exception:
622
+ raise ValueError(
623
+ "Tried to split a class or function. It did not work. Error comes from: \n```\n"
624
+ + theoretical_code_header
625
+ + "\n```\n"
626
+ )
627
+ theoretical_code_header = theoretical_code_header.replace(theoretical_name, observed_obj_name)
628
+
629
+ # Find the first diff. Line 0 is special since we need to compare with the function/class names ignored.
630
+ diff_index = 0
631
+ if theoretical_code_header != observed_code_header:
632
+ return 0
633
+
634
+ diff_index = 1
635
+ for observed_line, theoretical_line in zip(observed_code.split("\n")[1:], theoretical_code.split("\n")[1:]):
636
+ if observed_line != theoretical_line:
637
+ return diff_index
638
+ diff_index += 1
639
+
640
+
641
+ def is_copy_consistent(filename: str, overwrite: bool = False, buffer: dict = None) -> Optional[List[Tuple[str, int]]]:
642
+ """
643
+ Check if the code commented as a copy in a file matches the original.
644
+
645
+ Args:
646
+ filename (`str`):
647
+ The name of the file to check.
648
+ overwrite (`bool`, *optional*, defaults to `False`):
649
+ Whether or not to overwrite the copies when they don't match.
650
+ buffer (`dict`, *optional*):
651
+ The buffer used to store the previous results in order to speed up the process.
652
+
653
+ Returns:
654
+ `Optional[List[Tuple[str, int]]]`: If `overwrite=False`, returns the list of differences as tuples `(str, int)`
655
+ with the name of the object having a diff and the line number where there is the first diff.
656
+ """
657
+ base_path = TRANSFORMERS_PATH if not filename.startswith("tests") else MODEL_TEST_PATH
658
+
659
+ with open(filename, "r", encoding="utf-8", newline="\n") as f:
660
+ lines = f.readlines()
661
+ diffs = []
662
+ line_index = 0
663
+ # Not a for loop cause `lines` is going to change (if `overwrite=True`).
664
+ search_re = _re_copy_warning_for_test_file if filename.startswith("tests") else _re_copy_warning
665
+ while line_index < len(lines):
666
+ search = search_re.search(lines[line_index])
667
+ if search is None:
668
+ line_index += 1
669
+ continue
670
+
671
+ # There is some copied code here, let's retrieve the original.
672
+ indent, object_name, replace_pattern = search.groups()
673
+
674
+ # Find the file lines, the object's code, and its blocks
675
+ try:
676
+ target_lines, theoretical_code, theoretical_code_splits = find_code_and_splits(
677
+ object_name, base_path, buffer=buffer
678
+ )
679
+ except Exception as exc:
680
+ exc.args = (f"Error while trying to find source code for {filename}.\n\n" + str(exc),)
681
+ raise
682
+
683
+ # code replaced by the patterns
684
+ theoretical_code_blocks = OrderedDict()
685
+ for name, start, end in theoretical_code_splits:
686
+ name = replace_code(name, replace_pattern)
687
+ code = "".join(target_lines[start:end])
688
+ code = replace_code(code, replace_pattern)
689
+ theoretical_code_blocks[name] = code
690
+
691
+ theoretical_indent = get_indent(theoretical_code)
692
+
693
+ # `start_index` is the index of the first line (the definition header) after `# Copied from`.
694
+ # (`indent != theoretical_indent` doesn't seem to occur so far, not sure what this case is for.)
695
+ start_index = line_index + 1 if indent == theoretical_indent else line_index
696
+ # enter the block body
697
+ line_index = start_index + 1
698
+
699
+ subcode = "\n".join(theoretical_code.split("\n")[1:])
700
+ indent = get_indent(subcode)
701
+ # Loop to check the observed code, stop when indentation diminishes or if we see a End copy comment.
702
+ # We can't call `find_block_end` directly as there is sth. special `# End copy"` here.
703
+ should_continue = True
704
+ while line_index < len(lines) and should_continue:
705
+ line_index += 1
706
+ if line_index >= len(lines):
707
+ break
708
+ line = lines[line_index]
709
+ # There is a special pattern `# End copy` to stop early. It's not documented cause it shouldn't really be
710
+ # used.
711
+ should_continue = _should_continue(line, indent) and re.search(f"^{indent}# End copy", line) is None
712
+ # `line_index` is outside the block
713
+ # Clean up empty lines at the end (if any).
714
+ while len(lines[line_index - 1]) <= 1:
715
+ line_index -= 1
716
+
717
+ # Split the observed code into blocks
718
+ observed_code_splits = split_code_into_blocks(lines, start_index, line_index, len(indent), backtrace=True)
719
+
720
+ is_class = lines[start_index].startswith(f"{' ' * (len(indent) - 4)}class ")
721
+ # sanity check
722
+ _sanity_check_splits(theoretical_code_splits, observed_code_splits, is_class=is_class, filename=filename)
723
+
724
+ # observed code in a structured way (a dict mapping block names to blocks' code)
725
+ observed_code_blocks = OrderedDict()
726
+ for name, start, end in observed_code_splits:
727
+ code = "".join(lines[start:end])
728
+ observed_code_blocks[name] = code
729
+
730
+ # Below, we change some names in `theoretical_code_blocks` and `observed_code_blocks`. These mappings map the
731
+ # original names to the modified names: this is used to restore the original order of the code blocks.
732
+ name_mappings_1 = {k: k for k in theoretical_code_blocks.keys()}
733
+ name_mappings_2 = {k: k for k in observed_code_blocks.keys()}
734
+
735
+ # Update code blocks' name and content:
736
+ # If `"# Ignore copy"` is found in a block of the observed code:
737
+ # 1. if it's a block only in the observed code --> add it to the theoretical code.
738
+ # 2. if it's also in the theoretical code () --> put its content (body) to the corresponding block under the
739
+ # same name in the theoretical code.
740
+ # In both cases, we change the name to have a prefix `_ignored_` so we know if we can discard them during the
741
+ # comparison.
742
+ ignored_existing_block_index = 0
743
+ ignored_new_block_index = 0
744
+ for name in list(observed_code_blocks.keys()):
745
+ code = observed_code_blocks[name]
746
+ if "# Ignore copy" in code:
747
+ if name in theoretical_code_blocks:
748
+ # in the target --> just copy the content
749
+ del theoretical_code_blocks[name]
750
+ theoretical_code_blocks[f"_ignored_existing_block_{ignored_existing_block_index}"] = code
751
+ name_mappings_1[name] = f"_ignored_existing_block_{ignored_existing_block_index}"
752
+
753
+ del observed_code_blocks[name]
754
+ observed_code_blocks[f"_ignored_existing_block_{ignored_existing_block_index}"] = code
755
+ name_mappings_2[name] = f"_ignored_existing_block_{ignored_existing_block_index}"
756
+ ignored_existing_block_index += 1
757
+ else:
758
+ # not in the target --> add it
759
+ theoretical_code_blocks[f"_ignored_new_block_{ignored_new_block_index}"] = code
760
+ name_mappings_1[f"_ignored_new_block_{ignored_new_block_index}"] = (
761
+ f"_ignored_new_block_{ignored_new_block_index}"
762
+ )
763
+
764
+ del observed_code_blocks[name]
765
+ observed_code_blocks[f"_ignored_new_block_{ignored_new_block_index}"] = code
766
+ name_mappings_2[name] = f"_ignored_new_block_{ignored_new_block_index}"
767
+ ignored_new_block_index += 1
768
+
769
+ # Respect the original block order:
770
+ # 1. in `theoretical_code_blocks`: the new blocks will follow the existing ones
771
+ # 2. in `observed_code_blocks`: the original order are kept with names modified potentially. This is necessary
772
+ # to compute the correct `diff_index` if `overwrite=True` and there is a diff.
773
+ theoretical_code_blocks = {
774
+ name_mappings_1[orig_name]: theoretical_code_blocks[name_mappings_1[orig_name]]
775
+ for orig_name in name_mappings_1
776
+ }
777
+ observed_code_blocks = {
778
+ name_mappings_2[orig_name]: observed_code_blocks[name_mappings_2[orig_name]]
779
+ for orig_name in name_mappings_2
780
+ }
781
+
782
+ # Ignore the blocks specified to be ignored. This is the version used to check if there is a mismatch
783
+ theoretical_code_blocks_clean = {
784
+ k: v
785
+ for k, v in theoretical_code_blocks.items()
786
+ if not (k.startswith(("_ignored_existing_block_", "_ignored_new_block_")))
787
+ }
788
+ theoretical_code = "".join(list(theoretical_code_blocks_clean.values()))
789
+
790
+ # stylify `theoretical_code` before compare (this is needed only when `replace_pattern` is not empty)
791
+ if replace_pattern:
792
+ theoretical_code = stylify(theoretical_code)
793
+ # Remove `\n\n` in `theoretical_code` before compare (so no empty line)
794
+ while "\n\n" in theoretical_code:
795
+ theoretical_code = theoretical_code.replace("\n\n", "\n")
796
+
797
+ # Compute `observed_code` where we don't include any empty line + keep track the line index between the
798
+ # original/processed `observed_code` so we can have the correct `diff_index`.
799
+ idx_to_orig_idx_mapping_for_observed_code_lines = {}
800
+ idx = -1
801
+ orig_idx = -1
802
+ observed_code = ""
803
+ for name, code in observed_code_blocks.items():
804
+ if code.endswith("\n"):
805
+ code = code[:-1]
806
+ for code_line in code.split("\n"):
807
+ orig_idx += 1
808
+ if code_line.strip() and not name.startswith(("_ignored_existing_block_", "_ignored_new_block_")):
809
+ idx += 1
810
+ observed_code += code_line + "\n"
811
+ idx_to_orig_idx_mapping_for_observed_code_lines[idx] = orig_idx
812
+
813
+ # Test for a diff and act accordingly.
814
+ diff_index = check_codes_match(observed_code, theoretical_code)
815
+ if diff_index is not None:
816
+ # switch to the index in the original `observed_code` (i.e. before removing empty lines)
817
+ diff_index = idx_to_orig_idx_mapping_for_observed_code_lines[diff_index]
818
+ diffs.append([object_name, diff_index + start_index + 1])
819
+ if overwrite:
820
+ # `theoretical_code_to_write` is a single string but may have several lines.
821
+ theoretical_code_to_write = stylify("".join(list(theoretical_code_blocks.values())))
822
+ lines = lines[:start_index] + [theoretical_code_to_write] + lines[line_index:]
823
+ # Here we treat it as a single entry in `lines`.
824
+ line_index = start_index + 1
825
+
826
+ if overwrite and len(diffs) > 0:
827
+ # Warn the user a file has been modified.
828
+ print(f"Detected changes, rewriting {filename}.")
829
+ with open(filename, "w", encoding="utf-8", newline="\n") as f:
830
+ f.writelines(lines)
831
+ return diffs
832
+
833
+
834
+ def check_copies(overwrite: bool = False, file: str = None):
835
+ """
836
+ Check every file is copy-consistent with the original. Also check the model list in the main README and other
837
+ READMEs are consistent.
838
+
839
+ Args:
840
+ overwrite (`bool`, *optional*, defaults to `False`):
841
+ Whether or not to overwrite the copies when they don't match.
842
+ file (`bool`, *optional*):
843
+ The path to a specific file to check and/or fix.
844
+ """
845
+ buffer = {}
846
+
847
+ if file is None:
848
+ all_files = glob.glob(os.path.join(TRANSFORMERS_PATH, "**/*.py"), recursive=True)
849
+ all_test_files = glob.glob(os.path.join(MODEL_TEST_PATH, "**/*.py"), recursive=True)
850
+ all_files = list(all_files) + list(all_test_files)
851
+ else:
852
+ all_files = [file]
853
+
854
+ diffs = []
855
+ for filename in all_files:
856
+ new_diffs = is_copy_consistent(filename, overwrite, buffer)
857
+ diffs += [f"- {filename}: copy does not match {d[0]} at line {d[1]}" for d in new_diffs]
858
+ if not overwrite and len(diffs) > 0:
859
+ diff = "\n".join(diffs)
860
+ raise Exception(
861
+ "Found the following copy inconsistencies:\n"
862
+ + diff
863
+ + "\nRun `make fix-copies` or `python utils/check_copies.py --fix_and_overwrite` to fix them."
864
+ )
865
+
866
+
867
+ def check_full_copies(overwrite: bool = False):
868
+ """
869
+ Check the files that are full copies of others (as indicated in `FULL_COPIES`) are copy-consistent.
870
+
871
+ Args:
872
+ overwrite (`bool`, *optional*, defaults to `False`):
873
+ Whether or not to overwrite the copies when they don't match.
874
+ """
875
+ diffs = []
876
+ for target, source in FULL_COPIES.items():
877
+ with open(source, "r", encoding="utf-8") as f:
878
+ source_code = f.read()
879
+ with open(target, "r", encoding="utf-8") as f:
880
+ target_code = f.read()
881
+ if source_code != target_code:
882
+ if overwrite:
883
+ with open(target, "w", encoding="utf-8") as f:
884
+ print(f"Replacing the content of {target} by the one of {source}.")
885
+ f.write(source_code)
886
+ else:
887
+ diffs.append(f"- {target}: copy does not match {source}.")
888
+
889
+ if not overwrite and len(diffs) > 0:
890
+ diff = "\n".join(diffs)
891
+ raise Exception(
892
+ "Found the following copy inconsistencies:\n"
893
+ + diff
894
+ + "\nRun `make fix-copies` or `python utils/check_copies.py --fix_and_overwrite` to fix them."
895
+ )
896
+
897
+
898
+ def get_model_list(filename: str, start_prompt: str, end_prompt: str) -> str:
899
+ """
900
+ Extracts the model list from a README.
901
+
902
+ Args:
903
+ filename (`str`): The name of the README file to check.
904
+ start_prompt (`str`): The string to look for that introduces the model list.
905
+ end_prompt (`str`): The string to look for that ends the model list.
906
+
907
+ Returns:
908
+ `str`: The model list.
909
+ """
910
+ with open(os.path.join(REPO_PATH, filename), "r", encoding="utf-8", newline="\n") as f:
911
+ lines = f.readlines()
912
+ # Find the start of the list.
913
+ start_index = 0
914
+ while not lines[start_index].startswith(start_prompt):
915
+ start_index += 1
916
+ start_index += 1
917
+
918
+ result = []
919
+ current_line = ""
920
+ end_index = start_index
921
+
922
+ # Keep going until the end of the list.
923
+ while not lines[end_index].startswith(end_prompt):
924
+ if lines[end_index].startswith("1."):
925
+ if len(current_line) > 1:
926
+ result.append(current_line)
927
+ current_line = lines[end_index]
928
+ elif len(lines[end_index]) > 1:
929
+ current_line = f"{current_line[:-1]} {lines[end_index].lstrip()}"
930
+ end_index += 1
931
+ if len(current_line) > 1:
932
+ result.append(current_line)
933
+
934
+ return "".join(result)
935
+
936
+
937
+ def convert_to_localized_md(model_list: str, localized_model_list: str, format_str: str) -> Tuple[bool, str]:
938
+ """
939
+ Compare the model list from the main README to the one in a localized README.
940
+
941
+ Args:
942
+ model_list (`str`): The model list in the main README.
943
+ localized_model_list (`str`): The model list in one of the localized README.
944
+ format_str (`str`):
945
+ The template for a model entry in the localized README (look at the `format_model_list` in the entries of
946
+ `LOCALIZED_READMES` for examples).
947
+
948
+ Returns:
949
+ `Tuple[bool, str]`: A tuple where the first value indicates if the READMEs match or not, and the second value
950
+ is the correct localized README.
951
+ """
952
+
953
+ def _rep(match):
954
+ title, model_link, paper_affiliations, paper_title_link, paper_authors, supplements = match.groups()
955
+ return format_str.format(
956
+ title=title,
957
+ model_link=model_link,
958
+ paper_affiliations=paper_affiliations,
959
+ paper_title_link=paper_title_link,
960
+ paper_authors=paper_authors,
961
+ supplements=" " + supplements.strip() if len(supplements) != 0 else "",
962
+ )
963
+
964
+ # This regex captures metadata from an English model description, including model title, model link,
965
+ # affiliations of the paper, title of the paper, authors of the paper, and supplemental data (see DistilBERT for
966
+ # example).
967
+ _re_capture_meta = re.compile(
968
+ r"\*\*\[([^\]]*)\]\(([^\)]*)\)\*\* \(from ([^)]*)\)[^\[]*([^\)]*\)).*?by (.*?[A-Za-z\*]{2,}?)\. (.*)$"
969
+ )
970
+ # This regex is used to synchronize title link.
971
+ _re_capture_title_link = re.compile(r"\*\*\[([^\]]*)\]\(([^\)]*)\)\*\*")
972
+ # This regex is used to synchronize paper title and link.
973
+ _re_capture_paper_link = re.compile(r" \[([^\]]*)\]\(([^\)]*)\)")
974
+
975
+ if len(localized_model_list) == 0:
976
+ localized_model_index = {}
977
+ else:
978
+ try:
979
+ localized_model_index = {
980
+ re.search(r"\*\*\[([^\]]*)", line).groups()[0]: line
981
+ for line in localized_model_list.strip().split("\n")
982
+ }
983
+ except AttributeError:
984
+ raise AttributeError("A model name in localized READMEs cannot be recognized.")
985
+
986
+ model_keys = [re.search(r"\*\*\[([^\]]*)", line).groups()[0] for line in model_list.strip().split("\n")]
987
+
988
+ # We exclude keys in localized README not in the main one.
989
+ readmes_match = not any(k not in model_keys for k in localized_model_index)
990
+ localized_model_index = {k: v for k, v in localized_model_index.items() if k in model_keys}
991
+
992
+ for model in model_list.strip().split("\n"):
993
+ title, model_link = _re_capture_title_link.search(model).groups()
994
+ if title not in localized_model_index:
995
+ readmes_match = False
996
+ # Add an anchor white space behind a model description string for regex.
997
+ # If metadata cannot be captured, the English version will be directly copied.
998
+ localized_model_index[title] = _re_capture_meta.sub(_rep, model + " ")
999
+ elif _re_fill_pattern.search(localized_model_index[title]) is not None:
1000
+ update = _re_capture_meta.sub(_rep, model + " ")
1001
+ if update != localized_model_index[title]:
1002
+ readmes_match = False
1003
+ localized_model_index[title] = update
1004
+ else:
1005
+ # Synchronize title link
1006
+ converted_model = _re_capture_title_link.sub(
1007
+ f"**[{title}]({model_link})**", localized_model_index[title], count=1
1008
+ )
1009
+
1010
+ # Synchronize paper title and its link (if found)
1011
+ paper_title_link = _re_capture_paper_link.search(model)
1012
+ if paper_title_link is not None:
1013
+ paper_title, paper_link = paper_title_link.groups()
1014
+ converted_model = _re_capture_paper_link.sub(
1015
+ f" [{paper_title}]({paper_link})", converted_model, count=1
1016
+ )
1017
+
1018
+ if converted_model != localized_model_index[title]:
1019
+ readmes_match = False
1020
+ localized_model_index[title] = converted_model
1021
+
1022
+ sorted_index = sorted(localized_model_index.items(), key=lambda x: x[0].lower())
1023
+
1024
+ return readmes_match, "\n".join((x[1] for x in sorted_index)) + "\n"
1025
+
1026
+
1027
+ # Map a model name with the name it has in the README for the check_readme check
1028
+ SPECIAL_MODEL_NAMES = {
1029
+ "Bert Generation": "BERT For Sequence Generation",
1030
+ "BigBird": "BigBird-RoBERTa",
1031
+ "Data2VecAudio": "Data2Vec",
1032
+ "Data2VecText": "Data2Vec",
1033
+ "Data2VecVision": "Data2Vec",
1034
+ "DonutSwin": "Swin Transformer",
1035
+ "Marian": "MarianMT",
1036
+ "MaskFormerSwin": "Swin Transformer",
1037
+ "OpenAI GPT-2": "GPT-2",
1038
+ "OpenAI GPT": "GPT",
1039
+ "Perceiver": "Perceiver IO",
1040
+ "SAM": "Segment Anything",
1041
+ "ViT": "Vision Transformer (ViT)",
1042
+ }
1043
+
1044
+ # Update this list with the models that shouldn't be in the README. This only concerns modular models or those who do
1045
+ # not have an associated paper.
1046
+ MODELS_NOT_IN_README = [
1047
+ "BertJapanese",
1048
+ "Encoder decoder",
1049
+ "FairSeq Machine-Translation",
1050
+ "HerBERT",
1051
+ "RetriBERT",
1052
+ "Speech Encoder decoder",
1053
+ "Speech2Text",
1054
+ "Speech2Text2",
1055
+ "TimmBackbone",
1056
+ "Vision Encoder decoder",
1057
+ "VisionTextDualEncoder",
1058
+ "CLIPVisionModel",
1059
+ "SiglipVisionModel",
1060
+ "ChineseCLIPVisionModel",
1061
+ "VitPoseBackbone",
1062
+ ]
1063
+
1064
+ # Template for new entries to add in the main README when we have missing models.
1065
+ README_TEMPLATE = (
1066
+ "1. **[{model_name}](https://huggingface.co/docs/main/transformers/model_doc/{model_type})** (from "
1067
+ "<FILL INSTITUTION>) released with the paper [<FILL PAPER TITLE>](<FILL ARKIV LINK>) by <FILL AUTHORS>."
1068
+ )
1069
+
1070
+
1071
+ if __name__ == "__main__":
1072
+ parser = argparse.ArgumentParser()
1073
+ parser.add_argument("--file", type=str, default=None, help="A specific file to check and/or fix")
1074
+ parser.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
1075
+ args = parser.parse_args()
1076
+
1077
+ check_copies(args.fix_and_overwrite, args.file)
1078
+ check_full_copies(args.fix_and_overwrite)
docs/transformers/utils/check_doc_toc.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ This script is responsible for cleaning the model section of the table of content by removing duplicates and sorting
17
+ the entries in alphabetical order.
18
+
19
+ Usage (from the root of the repo):
20
+
21
+ Check that the table of content is properly sorted (used in `make quality`):
22
+
23
+ ```bash
24
+ python utils/check_doc_toc.py
25
+ ```
26
+
27
+ Auto-sort the table of content if it is not properly sorted (used in `make style`):
28
+
29
+ ```bash
30
+ python utils/check_doc_toc.py --fix_and_overwrite
31
+ ```
32
+ """
33
+
34
+ import argparse
35
+ from collections import defaultdict
36
+ from typing import List
37
+
38
+ import yaml
39
+
40
+
41
+ PATH_TO_TOC = "docs/source/en/_toctree.yml"
42
+
43
+
44
+ def clean_model_doc_toc(model_doc: List[dict]) -> List[dict]:
45
+ """
46
+ Cleans a section of the table of content of the model documentation (one specific modality) by removing duplicates
47
+ and sorting models alphabetically.
48
+
49
+ Args:
50
+ model_doc (`List[dict]`):
51
+ The list of dictionaries extracted from the `_toctree.yml` file for this specific modality.
52
+
53
+ Returns:
54
+ `List[dict]`: List of dictionaries like the input, but cleaned up and sorted.
55
+ """
56
+ counts = defaultdict(int)
57
+ for doc in model_doc:
58
+ counts[doc["local"]] += 1
59
+ duplicates = [key for key, value in counts.items() if value > 1]
60
+
61
+ new_doc = []
62
+ for duplicate_key in duplicates:
63
+ titles = list({doc["title"] for doc in model_doc if doc["local"] == duplicate_key})
64
+ if len(titles) > 1:
65
+ raise ValueError(
66
+ f"{duplicate_key} is present several times in the documentation table of content at "
67
+ "`docs/source/en/_toctree.yml` with different *Title* values. Choose one of those and remove the "
68
+ "others."
69
+ )
70
+ # Only add this once
71
+ new_doc.append({"local": duplicate_key, "title": titles[0]})
72
+
73
+ # Add none duplicate-keys
74
+ new_doc.extend([doc for doc in model_doc if counts[doc["local"]] == 1])
75
+
76
+ # Sort
77
+ return sorted(new_doc, key=lambda s: s["title"].lower())
78
+
79
+
80
+ def check_model_doc(overwrite: bool = False):
81
+ """
82
+ Check that the content of the table of content in `_toctree.yml` is clean (no duplicates and sorted for the model
83
+ API doc) and potentially auto-cleans it.
84
+
85
+ Args:
86
+ overwrite (`bool`, *optional*, defaults to `False`):
87
+ Whether to just check if the TOC is clean or to auto-clean it (when `overwrite=True`).
88
+ """
89
+ with open(PATH_TO_TOC, encoding="utf-8") as f:
90
+ content = yaml.safe_load(f.read())
91
+
92
+ # Get to the API doc
93
+ api_idx = 0
94
+ while content[api_idx]["title"] != "API":
95
+ api_idx += 1
96
+ api_doc = content[api_idx]["sections"]
97
+
98
+ # Then to the model doc
99
+ model_idx = 0
100
+ while api_doc[model_idx]["title"] != "Models":
101
+ model_idx += 1
102
+
103
+ model_doc = api_doc[model_idx]["sections"]
104
+
105
+ # Extract the modalities and clean them one by one.
106
+ modalities_docs = [(idx, section) for idx, section in enumerate(model_doc) if "sections" in section]
107
+ diff = False
108
+ for idx, modality_doc in modalities_docs:
109
+ old_modality_doc = modality_doc["sections"]
110
+ new_modality_doc = clean_model_doc_toc(old_modality_doc)
111
+
112
+ if old_modality_doc != new_modality_doc:
113
+ diff = True
114
+ if overwrite:
115
+ model_doc[idx]["sections"] = new_modality_doc
116
+
117
+ if diff:
118
+ if overwrite:
119
+ api_doc[model_idx]["sections"] = model_doc
120
+ content[api_idx]["sections"] = api_doc
121
+ with open(PATH_TO_TOC, "w", encoding="utf-8") as f:
122
+ f.write(yaml.dump(content, allow_unicode=True))
123
+ else:
124
+ raise ValueError(
125
+ "The model doc part of the table of content is not properly sorted, run `make style` to fix this."
126
+ )
127
+
128
+
129
+ if __name__ == "__main__":
130
+ parser = argparse.ArgumentParser()
131
+ parser.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
132
+ args = parser.parse_args()
133
+
134
+ check_model_doc(args.fix_and_overwrite)
docs/transformers/utils/check_docstrings.py ADDED
@@ -0,0 +1,1061 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Utility that checks all docstrings of public objects have an argument section matching their signature.
17
+
18
+ Use from the root of the repo with:
19
+
20
+ ```bash
21
+ python utils/check_docstrings.py
22
+ ```
23
+
24
+ for a check that will error in case of inconsistencies (used by `make repo-consistency`).
25
+
26
+ To auto-fix issues run:
27
+
28
+ ```bash
29
+ python utils/check_docstrings.py --fix_and_overwrite
30
+ ```
31
+
32
+ which is used by `make fix-copies` (note that this fills what it cans, you might have to manually fill information
33
+ like argument descriptions).
34
+ """
35
+
36
+ import argparse
37
+ import ast
38
+ import enum
39
+ import inspect
40
+ import operator as op
41
+ import re
42
+ from pathlib import Path
43
+ from typing import Any, Optional, Tuple, Union
44
+
45
+ from check_repo import ignore_undocumented
46
+ from git import Repo
47
+
48
+ from transformers.utils import direct_transformers_import
49
+
50
+
51
+ PATH_TO_REPO = Path(__file__).parent.parent.resolve()
52
+ PATH_TO_TRANSFORMERS = Path("src").resolve() / "transformers"
53
+
54
+ # This is to make sure the transformers module imported is the one in the repo.
55
+ transformers = direct_transformers_import(PATH_TO_TRANSFORMERS)
56
+
57
+ OPTIONAL_KEYWORD = "*optional*"
58
+ # Re pattern that catches args blocks in docstrings (with all variation around the name supported).
59
+ _re_args = re.compile(r"^\s*(Args?|Arguments?|Attributes?|Params?|Parameters?):\s*$")
60
+ # Re pattern that parses the start of an arg block: catches <name> (<description>) in those lines.
61
+ _re_parse_arg = re.compile(r"^(\s*)(\S+)\s+\((.+)\)(?:\:|$)")
62
+ # Re pattern that parses the end of a description of an arg (catches the default in *optional*, defaults to xxx).
63
+ _re_parse_description = re.compile(r"\*optional\*, defaults to (.*)$")
64
+
65
+
66
+ # This is a temporary list of objects to ignore while we progressively fix them. Do not add anything here, fix the
67
+ # docstrings instead. If formatting should be ignored for the docstring, you can put a comment # no-format on the
68
+ # line before the docstring.
69
+ OBJECTS_TO_IGNORE = [
70
+ "Llama4Processor",
71
+ # Deprecated
72
+ "InputExample",
73
+ "InputFeatures",
74
+ # Signature is *args/**kwargs
75
+ "TFSequenceSummary",
76
+ "TFBertTokenizer",
77
+ "TFGPT2Tokenizer",
78
+ # Missing arguments in the docstring
79
+ "ASTFeatureExtractor",
80
+ "AlbertModel",
81
+ "AlbertTokenizerFast",
82
+ "AlignTextModel",
83
+ "AlignVisionConfig",
84
+ "AudioClassificationPipeline",
85
+ "AutoformerConfig",
86
+ "AutomaticSpeechRecognitionPipeline",
87
+ "BarkCoarseConfig",
88
+ "BarkConfig",
89
+ "BarkFineConfig",
90
+ "BarkSemanticConfig",
91
+ "BartConfig",
92
+ "BartTokenizerFast",
93
+ "BarthezTokenizerFast",
94
+ "BeitModel",
95
+ "BertConfig",
96
+ "BertJapaneseTokenizer",
97
+ "BertModel",
98
+ "BertTokenizerFast",
99
+ "BigBirdConfig",
100
+ "BigBirdForQuestionAnswering",
101
+ "BigBirdModel",
102
+ "BigBirdPegasusConfig",
103
+ "BigBirdTokenizerFast",
104
+ "BitImageProcessor",
105
+ "BlenderbotConfig",
106
+ "BlenderbotSmallConfig",
107
+ "BlenderbotSmallTokenizerFast",
108
+ "BlenderbotTokenizerFast",
109
+ "Blip2VisionConfig",
110
+ "BlipTextConfig",
111
+ "BlipVisionConfig",
112
+ "BloomConfig",
113
+ "BloomTokenizerFast",
114
+ "BridgeTowerTextConfig",
115
+ "BridgeTowerVisionConfig",
116
+ "BrosModel",
117
+ "CamembertConfig",
118
+ "CamembertModel",
119
+ "CamembertTokenizerFast",
120
+ "CanineModel",
121
+ "CanineTokenizer",
122
+ "ChineseCLIPTextModel",
123
+ "ClapTextConfig",
124
+ "ConditionalDetrConfig",
125
+ "ConditionalDetrImageProcessor",
126
+ "ConvBertConfig",
127
+ "ConvBertTokenizerFast",
128
+ "ConvNextConfig",
129
+ "ConvNextV2Config",
130
+ "CpmAntTokenizer",
131
+ "CvtConfig",
132
+ "CvtModel",
133
+ "DeiTImageProcessor",
134
+ "DPRReaderTokenizer",
135
+ "DPRReaderTokenizerFast",
136
+ "DPTModel",
137
+ "Data2VecAudioConfig",
138
+ "Data2VecTextConfig",
139
+ "Data2VecTextModel",
140
+ "Data2VecVisionModel",
141
+ "DataCollatorForLanguageModeling",
142
+ "DebertaConfig",
143
+ "DebertaV2Config",
144
+ "DebertaV2Tokenizer",
145
+ "DebertaV2TokenizerFast",
146
+ "DecisionTransformerConfig",
147
+ "DeformableDetrConfig",
148
+ "DeformableDetrImageProcessor",
149
+ "DeiTModel",
150
+ "DepthEstimationPipeline",
151
+ "DetaConfig",
152
+ "DetaImageProcessor",
153
+ "DetrConfig",
154
+ "DetrImageProcessor",
155
+ "DinatModel",
156
+ "DistilBertConfig",
157
+ "DistilBertTokenizerFast",
158
+ "DocumentQuestionAnsweringPipeline",
159
+ "DonutSwinModel",
160
+ "EarlyStoppingCallback",
161
+ "EfficientFormerConfig",
162
+ "EfficientFormerImageProcessor",
163
+ "EfficientNetConfig",
164
+ "ElectraConfig",
165
+ "ElectraTokenizerFast",
166
+ "EncoderDecoderModel",
167
+ "ErnieMModel",
168
+ "ErnieModel",
169
+ "ErnieMTokenizer",
170
+ "EsmConfig",
171
+ "EsmModel",
172
+ "FlaxAlbertForMaskedLM",
173
+ "FlaxAlbertForMultipleChoice",
174
+ "FlaxAlbertForPreTraining",
175
+ "FlaxAlbertForQuestionAnswering",
176
+ "FlaxAlbertForSequenceClassification",
177
+ "FlaxAlbertForTokenClassification",
178
+ "FlaxAlbertModel",
179
+ "FlaxBartForCausalLM",
180
+ "FlaxBartForConditionalGeneration",
181
+ "FlaxBartForQuestionAnswering",
182
+ "FlaxBartForSequenceClassification",
183
+ "FlaxBartModel",
184
+ "FlaxBeitForImageClassification",
185
+ "FlaxBeitForMaskedImageModeling",
186
+ "FlaxBeitModel",
187
+ "FlaxBertForCausalLM",
188
+ "FlaxBertForMaskedLM",
189
+ "FlaxBertForMultipleChoice",
190
+ "FlaxBertForNextSentencePrediction",
191
+ "FlaxBertForPreTraining",
192
+ "FlaxBertForQuestionAnswering",
193
+ "FlaxBertForSequenceClassification",
194
+ "FlaxBertForTokenClassification",
195
+ "FlaxBertModel",
196
+ "FlaxBigBirdForCausalLM",
197
+ "FlaxBigBirdForMaskedLM",
198
+ "FlaxBigBirdForMultipleChoice",
199
+ "FlaxBigBirdForPreTraining",
200
+ "FlaxBigBirdForQuestionAnswering",
201
+ "FlaxBigBirdForSequenceClassification",
202
+ "FlaxBigBirdForTokenClassification",
203
+ "FlaxBigBirdModel",
204
+ "FlaxBlenderbotForConditionalGeneration",
205
+ "FlaxBlenderbotModel",
206
+ "FlaxBlenderbotSmallForConditionalGeneration",
207
+ "FlaxBlenderbotSmallModel",
208
+ "FlaxBloomForCausalLM",
209
+ "FlaxBloomModel",
210
+ "FlaxCLIPModel",
211
+ "FlaxDinov2ForImageClassification",
212
+ "FlaxDinov2Model",
213
+ "FlaxDistilBertForMaskedLM",
214
+ "FlaxDistilBertForMultipleChoice",
215
+ "FlaxDistilBertForQuestionAnswering",
216
+ "FlaxDistilBertForSequenceClassification",
217
+ "FlaxDistilBertForTokenClassification",
218
+ "FlaxDistilBertModel",
219
+ "FlaxElectraForCausalLM",
220
+ "FlaxElectraForMaskedLM",
221
+ "FlaxElectraForMultipleChoice",
222
+ "FlaxElectraForPreTraining",
223
+ "FlaxElectraForQuestionAnswering",
224
+ "FlaxElectraForSequenceClassification",
225
+ "FlaxElectraForTokenClassification",
226
+ "FlaxElectraModel",
227
+ "FlaxEncoderDecoderModel",
228
+ "FlaxGPT2LMHeadModel",
229
+ "FlaxGPT2Model",
230
+ "FlaxGPTJForCausalLM",
231
+ "FlaxGPTJModel",
232
+ "FlaxGPTNeoForCausalLM",
233
+ "FlaxGPTNeoModel",
234
+ "FlaxLlamaForCausalLM",
235
+ "FlaxLlamaModel",
236
+ "FlaxGemmaForCausalLM",
237
+ "FlaxGemmaModel",
238
+ "FlaxMBartForConditionalGeneration",
239
+ "FlaxMBartForQuestionAnswering",
240
+ "FlaxMBartForSequenceClassification",
241
+ "FlaxMBartModel",
242
+ "FlaxMarianMTModel",
243
+ "FlaxMarianModel",
244
+ "FlaxMistralForCausalLM",
245
+ "FlaxMistralModel",
246
+ "FlaxOPTForCausalLM",
247
+ "FlaxPegasusForConditionalGeneration",
248
+ "FlaxPegasusModel",
249
+ "FlaxRegNetForImageClassification",
250
+ "FlaxRegNetModel",
251
+ "FlaxResNetForImageClassification",
252
+ "FlaxResNetModel",
253
+ "FlaxRoFormerForMaskedLM",
254
+ "FlaxRoFormerForMultipleChoice",
255
+ "FlaxRoFormerForQuestionAnswering",
256
+ "FlaxRoFormerForSequenceClassification",
257
+ "FlaxRoFormerForTokenClassification",
258
+ "FlaxRoFormerModel",
259
+ "FlaxRobertaForCausalLM",
260
+ "FlaxRobertaForMaskedLM",
261
+ "FlaxRobertaForMultipleChoice",
262
+ "FlaxRobertaForQuestionAnswering",
263
+ "FlaxRobertaForSequenceClassification",
264
+ "FlaxRobertaForTokenClassification",
265
+ "FlaxRobertaModel",
266
+ "FlaxRobertaPreLayerNormForCausalLM",
267
+ "FlaxRobertaPreLayerNormForMaskedLM",
268
+ "FlaxRobertaPreLayerNormForMultipleChoice",
269
+ "FlaxRobertaPreLayerNormForQuestionAnswering",
270
+ "FlaxRobertaPreLayerNormForSequenceClassification",
271
+ "FlaxRobertaPreLayerNormForTokenClassification",
272
+ "FlaxRobertaPreLayerNormModel",
273
+ "FlaxSpeechEncoderDecoderModel",
274
+ "FlaxViTForImageClassification",
275
+ "FlaxViTModel",
276
+ "FlaxVisionEncoderDecoderModel",
277
+ "FlaxVisionTextDualEncoderModel",
278
+ "FlaxWav2Vec2ForCTC",
279
+ "FlaxWav2Vec2ForPreTraining",
280
+ "FlaxWav2Vec2Model",
281
+ "FlaxWhisperForAudioClassification",
282
+ "FlaxWhisperForConditionalGeneration",
283
+ "FlaxWhisperModel",
284
+ "FlaxWhisperTimeStampLogitsProcessor",
285
+ "FlaxXGLMForCausalLM",
286
+ "FlaxXGLMModel",
287
+ "FlaxXLMRobertaForCausalLM",
288
+ "FlaxXLMRobertaForMaskedLM",
289
+ "FlaxXLMRobertaForMultipleChoice",
290
+ "FlaxXLMRobertaForQuestionAnswering",
291
+ "FlaxXLMRobertaForSequenceClassification",
292
+ "FlaxXLMRobertaForTokenClassification",
293
+ "FlaxXLMRobertaModel",
294
+ "FNetConfig",
295
+ "FNetModel",
296
+ "FNetTokenizerFast",
297
+ "FSMTConfig",
298
+ "FeatureExtractionPipeline",
299
+ "FillMaskPipeline",
300
+ "FlaubertConfig",
301
+ "FlavaConfig",
302
+ "FlavaForPreTraining",
303
+ "FlavaImageModel",
304
+ "FlavaImageProcessor",
305
+ "FlavaMultimodalModel",
306
+ "FlavaTextConfig",
307
+ "FlavaTextModel",
308
+ "FocalNetModel",
309
+ "FunnelTokenizerFast",
310
+ "GPTBigCodeConfig",
311
+ "GPTJConfig",
312
+ "GPTNeoXConfig",
313
+ "GPTNeoXJapaneseConfig",
314
+ "GPTNeoXTokenizerFast",
315
+ "GPTSanJapaneseConfig",
316
+ "GitConfig",
317
+ "GitVisionConfig",
318
+ "GraphormerConfig",
319
+ "GroupViTTextConfig",
320
+ "GroupViTVisionConfig",
321
+ "HerbertTokenizerFast",
322
+ "HubertConfig",
323
+ "HubertForCTC",
324
+ "IBertConfig",
325
+ "IBertModel",
326
+ "IdeficsConfig",
327
+ "IdeficsProcessor",
328
+ "IJepaModel",
329
+ "ImageClassificationPipeline",
330
+ "ImageFeatureExtractionPipeline",
331
+ "ImageGPTConfig",
332
+ "ImageSegmentationPipeline",
333
+ "ImageTextToTextPipeline",
334
+ "ImageToImagePipeline",
335
+ "ImageToTextPipeline",
336
+ "InformerConfig",
337
+ "JukeboxPriorConfig",
338
+ "JukeboxTokenizer",
339
+ "LEDConfig",
340
+ "LEDTokenizerFast",
341
+ "LayoutLMForQuestionAnswering",
342
+ "LayoutLMTokenizerFast",
343
+ "LayoutLMv2Config",
344
+ "LayoutLMv2ForQuestionAnswering",
345
+ "LayoutLMv2TokenizerFast",
346
+ "LayoutLMv3Config",
347
+ "LayoutLMv3ImageProcessor",
348
+ "LayoutLMv3TokenizerFast",
349
+ "LayoutXLMTokenizerFast",
350
+ "LevitConfig",
351
+ "LiltConfig",
352
+ "LiltModel",
353
+ "LongT5Config",
354
+ "LongformerConfig",
355
+ "LongformerModel",
356
+ "LongformerTokenizerFast",
357
+ "LukeModel",
358
+ "LukeTokenizer",
359
+ "LxmertTokenizerFast",
360
+ "M2M100Config",
361
+ "M2M100Tokenizer",
362
+ "MarkupLMProcessor",
363
+ "MaskGenerationPipeline",
364
+ "MBart50TokenizerFast",
365
+ "MBartConfig",
366
+ "MCTCTFeatureExtractor",
367
+ "MPNetConfig",
368
+ "MPNetModel",
369
+ "MPNetTokenizerFast",
370
+ "MT5Config",
371
+ "MT5TokenizerFast",
372
+ "MarianConfig",
373
+ "MarianTokenizer",
374
+ "MarkupLMConfig",
375
+ "MarkupLMModel",
376
+ "MarkupLMTokenizer",
377
+ "MarkupLMTokenizerFast",
378
+ "Mask2FormerConfig",
379
+ "MaskFormerConfig",
380
+ "MaxTimeCriteria",
381
+ "MegaConfig",
382
+ "MegaModel",
383
+ "MegatronBertConfig",
384
+ "MegatronBertForPreTraining",
385
+ "MegatronBertModel",
386
+ "MLCDVisionConfig",
387
+ "MobileBertConfig",
388
+ "MobileBertModel",
389
+ "MobileBertTokenizerFast",
390
+ "MobileNetV1ImageProcessor",
391
+ "MobileNetV1Model",
392
+ "MobileNetV2ImageProcessor",
393
+ "MobileNetV2Model",
394
+ "MobileViTModel",
395
+ "MobileViTV2Model",
396
+ "MLukeTokenizer",
397
+ "MraConfig",
398
+ "MusicgenDecoderConfig",
399
+ "MusicgenForConditionalGeneration",
400
+ "MusicgenMelodyForConditionalGeneration",
401
+ "MvpConfig",
402
+ "MvpTokenizerFast",
403
+ "MT5Tokenizer",
404
+ "NatModel",
405
+ "NerPipeline",
406
+ "NezhaConfig",
407
+ "NezhaModel",
408
+ "NllbMoeConfig",
409
+ "NllbTokenizer",
410
+ "NllbTokenizerFast",
411
+ "NystromformerConfig",
412
+ "OPTConfig",
413
+ "ObjectDetectionPipeline",
414
+ "OneFormerProcessor",
415
+ "OpenAIGPTTokenizerFast",
416
+ "OpenLlamaConfig",
417
+ "PLBartConfig",
418
+ "PegasusConfig",
419
+ "PegasusTokenizer",
420
+ "PegasusTokenizerFast",
421
+ "PegasusXConfig",
422
+ "PerceiverImageProcessor",
423
+ "PerceiverModel",
424
+ "PerceiverTokenizer",
425
+ "PersimmonConfig",
426
+ "Pipeline",
427
+ "Pix2StructConfig",
428
+ "Pix2StructTextConfig",
429
+ "PLBartTokenizer",
430
+ "Pop2PianoConfig",
431
+ "PreTrainedTokenizer",
432
+ "PreTrainedTokenizerBase",
433
+ "PreTrainedTokenizerFast",
434
+ "PrefixConstrainedLogitsProcessor",
435
+ "ProphetNetConfig",
436
+ "QDQBertConfig",
437
+ "QDQBertModel",
438
+ "QuestionAnsweringPipeline",
439
+ "RagConfig",
440
+ "RagModel",
441
+ "RagRetriever",
442
+ "RagSequenceForGeneration",
443
+ "RagTokenForGeneration",
444
+ "RealmConfig",
445
+ "RealmForOpenQA",
446
+ "RealmScorer",
447
+ "RealmTokenizerFast",
448
+ "ReformerConfig",
449
+ "ReformerTokenizerFast",
450
+ "RegNetConfig",
451
+ "RemBertConfig",
452
+ "RemBertModel",
453
+ "RemBertTokenizer",
454
+ "RemBertTokenizerFast",
455
+ "RetriBertConfig",
456
+ "RetriBertTokenizerFast",
457
+ "RoCBertConfig",
458
+ "RoCBertModel",
459
+ "RoCBertTokenizer",
460
+ "RoFormerConfig",
461
+ "RobertaConfig",
462
+ "RobertaModel",
463
+ "RobertaPreLayerNormConfig",
464
+ "RobertaPreLayerNormModel",
465
+ "RobertaTokenizerFast",
466
+ "SEWConfig",
467
+ "SEWDConfig",
468
+ "SEWDForCTC",
469
+ "SEWForCTC",
470
+ "SamConfig",
471
+ "SamPromptEncoderConfig",
472
+ "SeamlessM4TConfig", # use of unconventional markdown
473
+ "SeamlessM4Tv2Config", # use of unconventional markdown
474
+ "Seq2SeqTrainingArguments",
475
+ "SpecialTokensMixin",
476
+ "Speech2Text2Config",
477
+ "Speech2Text2Tokenizer",
478
+ "Speech2TextTokenizer",
479
+ "SpeechEncoderDecoderModel",
480
+ "SpeechT5Config",
481
+ "SpeechT5Model",
482
+ "SplinterConfig",
483
+ "SplinterTokenizerFast",
484
+ "SqueezeBertTokenizerFast",
485
+ "SummarizationPipeline",
486
+ "Swin2SRImageProcessor",
487
+ "Swinv2Model",
488
+ "SwitchTransformersConfig",
489
+ "T5Config",
490
+ "T5Tokenizer",
491
+ "T5TokenizerFast",
492
+ "TableQuestionAnsweringPipeline",
493
+ "TableTransformerConfig",
494
+ "TapasConfig",
495
+ "TapasModel",
496
+ "TapasTokenizer",
497
+ "Text2TextGenerationPipeline",
498
+ "TextClassificationPipeline",
499
+ "TextGenerationPipeline",
500
+ "TFBartForConditionalGeneration",
501
+ "TFBartForSequenceClassification",
502
+ "TFBartModel",
503
+ "TFBertModel",
504
+ "TFConvNextModel",
505
+ "TFData2VecVisionModel",
506
+ "TFDeiTModel",
507
+ "TFEncoderDecoderModel",
508
+ "TFEsmModel",
509
+ "TFMobileViTModel",
510
+ "TFRagModel",
511
+ "TFRagSequenceForGeneration",
512
+ "TFRagTokenForGeneration",
513
+ "TFRepetitionPenaltyLogitsProcessor",
514
+ "TFSwinModel",
515
+ "TFViTModel",
516
+ "TFVisionEncoderDecoderModel",
517
+ "TFVisionTextDualEncoderModel",
518
+ "TFXGLMForCausalLM",
519
+ "TFXGLMModel",
520
+ "TimeSeriesTransformerConfig",
521
+ "TokenClassificationPipeline",
522
+ "TrOCRConfig",
523
+ "Phi4MultimodalProcessor",
524
+ "TrainerState",
525
+ "TrainingArguments",
526
+ "TrajectoryTransformerConfig",
527
+ "TranslationPipeline",
528
+ "TvltImageProcessor",
529
+ "UMT5Config",
530
+ "UperNetConfig",
531
+ "UperNetForSemanticSegmentation",
532
+ "ViTHybridImageProcessor",
533
+ "ViTHybridModel",
534
+ "ViTMSNModel",
535
+ "ViTModel",
536
+ "VideoClassificationPipeline",
537
+ "ViltConfig",
538
+ "ViltForImagesAndTextClassification",
539
+ "ViltModel",
540
+ "VisionEncoderDecoderModel",
541
+ "VisionTextDualEncoderModel",
542
+ "VisualBertConfig",
543
+ "VisualBertModel",
544
+ "VisualQuestionAnsweringPipeline",
545
+ "VitMatteForImageMatting",
546
+ "VitsTokenizer",
547
+ "VivitModel",
548
+ "Wav2Vec2BertForCTC",
549
+ "Wav2Vec2CTCTokenizer",
550
+ "Wav2Vec2Config",
551
+ "Wav2Vec2ConformerConfig",
552
+ "Wav2Vec2ConformerForCTC",
553
+ "Wav2Vec2FeatureExtractor",
554
+ "Wav2Vec2PhonemeCTCTokenizer",
555
+ "WavLMConfig",
556
+ "WavLMForCTC",
557
+ "WhisperConfig",
558
+ "WhisperFeatureExtractor",
559
+ "WhisperForAudioClassification",
560
+ "XCLIPTextConfig",
561
+ "XCLIPVisionConfig",
562
+ "XGLMConfig",
563
+ "XGLMModel",
564
+ "XGLMTokenizerFast",
565
+ "XLMConfig",
566
+ "XLMProphetNetConfig",
567
+ "XLMRobertaConfig",
568
+ "XLMRobertaModel",
569
+ "XLMRobertaTokenizerFast",
570
+ "XLMRobertaXLConfig",
571
+ "XLMRobertaXLModel",
572
+ "XLNetConfig",
573
+ "XLNetTokenizerFast",
574
+ "XmodConfig",
575
+ "XmodModel",
576
+ "YolosImageProcessor",
577
+ "YolosModel",
578
+ "YosoConfig",
579
+ "ZeroShotAudioClassificationPipeline",
580
+ "ZeroShotClassificationPipeline",
581
+ "ZeroShotImageClassificationPipeline",
582
+ "ZeroShotObjectDetectionPipeline",
583
+ "Llama4TextConfig",
584
+ ]
585
+
586
+ # Supported math operations when interpreting the value of defaults.
587
+ MATH_OPERATORS = {
588
+ ast.Add: op.add,
589
+ ast.Sub: op.sub,
590
+ ast.Mult: op.mul,
591
+ ast.Div: op.truediv,
592
+ ast.Pow: op.pow,
593
+ ast.BitXor: op.xor,
594
+ ast.USub: op.neg,
595
+ }
596
+
597
+
598
+ def find_indent(line: str) -> int:
599
+ """
600
+ Returns the number of spaces that start a line indent.
601
+ """
602
+ search = re.search(r"^(\s*)(?:\S|$)", line)
603
+ if search is None:
604
+ return 0
605
+ return len(search.groups()[0])
606
+
607
+
608
+ def stringify_default(default: Any) -> str:
609
+ """
610
+ Returns the string representation of a default value, as used in docstring: numbers are left as is, all other
611
+ objects are in backtiks.
612
+
613
+ Args:
614
+ default (`Any`): The default value to process
615
+
616
+ Returns:
617
+ `str`: The string representation of that default.
618
+ """
619
+ if isinstance(default, bool):
620
+ # We need to test for bool first as a bool passes isinstance(xxx, (int, float))
621
+ return f"`{default}`"
622
+ elif isinstance(default, enum.Enum):
623
+ # We need to test for enum first as an enum with int values will pass isinstance(xxx, (int, float))
624
+ return f"`{str(default)}`"
625
+ elif isinstance(default, int):
626
+ return str(default)
627
+ elif isinstance(default, float):
628
+ result = str(default)
629
+ return str(round(default, 2)) if len(result) > 6 else result
630
+ elif isinstance(default, str):
631
+ return str(default) if default.isnumeric() else f'`"{default}"`'
632
+ elif isinstance(default, type):
633
+ return f"`{default.__name__}`"
634
+ else:
635
+ return f"`{default}`"
636
+
637
+
638
+ def eval_math_expression(expression: str) -> Optional[Union[float, int]]:
639
+ # Mainly taken from the excellent https://stackoverflow.com/a/9558001
640
+ """
641
+ Evaluate (safely) a mathematial expression and returns its value.
642
+
643
+ Args:
644
+ expression (`str`): The expression to evaluate.
645
+
646
+ Returns:
647
+ `Optional[Union[float, int]]`: Returns `None` if the evaluation fails in any way and the value computed
648
+ otherwise.
649
+
650
+ Example:
651
+
652
+ ```py
653
+ >>> eval_expr('2^6')
654
+ 4
655
+ >>> eval_expr('2**6')
656
+ 64
657
+ >>> eval_expr('1 + 2*3**(4^5) / (6 + -7)')
658
+ -5.0
659
+ ```
660
+ """
661
+ try:
662
+ return eval_node(ast.parse(expression, mode="eval").body)
663
+ except TypeError:
664
+ return
665
+
666
+
667
+ def eval_node(node):
668
+ if isinstance(node, ast.Num): # <number>
669
+ return node.n
670
+ elif isinstance(node, ast.BinOp): # <left> <operator> <right>
671
+ return MATH_OPERATORS[type(node.op)](eval_node(node.left), eval_node(node.right))
672
+ elif isinstance(node, ast.UnaryOp): # <operator> <operand> e.g., -1
673
+ return MATH_OPERATORS[type(node.op)](eval_node(node.operand))
674
+ else:
675
+ raise TypeError(node)
676
+
677
+
678
+ def replace_default_in_arg_description(description: str, default: Any) -> str:
679
+ """
680
+ Catches the default value in the description of an argument inside a docstring and replaces it by the value passed.
681
+
682
+ Args:
683
+ description (`str`): The description of an argument in a docstring to process.
684
+ default (`Any`): The default value that would be in the docstring of that argument.
685
+
686
+ Returns:
687
+ `str`: The description updated with the new default value.
688
+ """
689
+ # Lots of docstrings have `optional` or **opational** instead of *optional* so we do this fix here.
690
+ description = description.replace("`optional`", OPTIONAL_KEYWORD)
691
+ description = description.replace("**optional**", OPTIONAL_KEYWORD)
692
+ if default is inspect._empty:
693
+ # No default, make sure the description doesn't have any either
694
+ idx = description.find(OPTIONAL_KEYWORD)
695
+ if idx != -1:
696
+ description = description[:idx].rstrip()
697
+ if description.endswith(","):
698
+ description = description[:-1].rstrip()
699
+ elif default is None:
700
+ # Default None are not written, we just set `*optional*`. If there is default that is not None specified in the
701
+ # description, we do not erase it (as sometimes we set the default to `None` because the default is a mutable
702
+ # object).
703
+ idx = description.find(OPTIONAL_KEYWORD)
704
+ if idx == -1:
705
+ description = f"{description}, {OPTIONAL_KEYWORD}"
706
+ elif re.search(r"defaults to `?None`?", description) is not None:
707
+ len_optional = len(OPTIONAL_KEYWORD)
708
+ description = description[: idx + len_optional]
709
+ else:
710
+ str_default = None
711
+ # For numbers we may have a default that is given by a math operation (1/255 is really popular). We don't
712
+ # want to replace those by their actual values.
713
+ if isinstance(default, (int, float)) and re.search("defaults to `?(.*?)(?:`|$)", description) is not None:
714
+ # Grab the default and evaluate it.
715
+ current_default = re.search("defaults to `?(.*?)(?:`|$)", description).groups()[0]
716
+ if default == eval_math_expression(current_default):
717
+ try:
718
+ # If it can be directly converted to the type of the default, it's a simple value
719
+ str_default = str(type(default)(current_default))
720
+ except Exception:
721
+ # Otherwise there is a math operator so we add a code block.
722
+ str_default = f"`{current_default}`"
723
+ elif isinstance(default, enum.Enum) and default.name == current_default.split(".")[-1]:
724
+ # When the default is an Enum (this is often the case for PIL.Image.Resampling), and the docstring
725
+ # matches the enum name, keep the existing docstring rather than clobbering it with the enum value.
726
+ str_default = f"`{current_default}`"
727
+
728
+ if str_default is None:
729
+ str_default = stringify_default(default)
730
+ # Make sure default match
731
+ if OPTIONAL_KEYWORD not in description:
732
+ description = f"{description}, {OPTIONAL_KEYWORD}, defaults to {str_default}"
733
+ elif _re_parse_description.search(description) is None:
734
+ idx = description.find(OPTIONAL_KEYWORD)
735
+ len_optional = len(OPTIONAL_KEYWORD)
736
+ description = f"{description[: idx + len_optional]}, defaults to {str_default}"
737
+ else:
738
+ description = _re_parse_description.sub(rf"*optional*, defaults to {str_default}", description)
739
+
740
+ return description
741
+
742
+
743
+ def get_default_description(arg: inspect.Parameter) -> str:
744
+ """
745
+ Builds a default description for a parameter that was not documented.
746
+
747
+ Args:
748
+ arg (`inspect.Parameter`): The argument in the signature to generate a description for.
749
+
750
+ Returns:
751
+ `str`: The description.
752
+ """
753
+ if arg.annotation is inspect._empty:
754
+ arg_type = "<fill_type>"
755
+ elif hasattr(arg.annotation, "__name__"):
756
+ arg_type = arg.annotation.__name__
757
+ else:
758
+ arg_type = str(arg.annotation)
759
+
760
+ if arg.default is inspect._empty:
761
+ return f"`{arg_type}`"
762
+ elif arg.default is None:
763
+ return f"`{arg_type}`, {OPTIONAL_KEYWORD}"
764
+ else:
765
+ str_default = stringify_default(arg.default)
766
+ return f"`{arg_type}`, {OPTIONAL_KEYWORD}, defaults to {str_default}"
767
+
768
+
769
+ def find_source_file(obj: Any) -> Path:
770
+ """
771
+ Finds the source file of an object.
772
+
773
+ Args:
774
+ obj (`Any`): The object whose source file we are looking for.
775
+
776
+ Returns:
777
+ `Path`: The source file.
778
+ """
779
+ module = obj.__module__
780
+ obj_file = PATH_TO_TRANSFORMERS
781
+ for part in module.split(".")[1:]:
782
+ obj_file = obj_file / part
783
+ return obj_file.with_suffix(".py")
784
+
785
+
786
+ def match_docstring_with_signature(obj: Any) -> Optional[Tuple[str, str]]:
787
+ """
788
+ Matches the docstring of an object with its signature.
789
+
790
+ Args:
791
+ obj (`Any`): The object to process.
792
+
793
+ Returns:
794
+ `Optional[Tuple[str, str]]`: Returns `None` if there is no docstring or no parameters documented in the
795
+ docstring, otherwise returns a tuple of two strings: the current documentation of the arguments in the
796
+ docstring and the one matched with the signature.
797
+ """
798
+ if len(getattr(obj, "__doc__", "")) == 0:
799
+ # Nothing to do, there is no docstring.
800
+ return
801
+
802
+ # Read the docstring in the source code to see if there is a special command to ignore this object.
803
+ try:
804
+ source, _ = inspect.getsourcelines(obj)
805
+ except OSError:
806
+ source = []
807
+
808
+ idx = 0
809
+ while idx < len(source) and '"""' not in source[idx]:
810
+ idx += 1
811
+
812
+ ignore_order = False
813
+ if idx < len(source):
814
+ line_before_docstring = source[idx - 1]
815
+ if re.search(r"^\s*#\s*no-format\s*$", line_before_docstring):
816
+ # This object is ignored
817
+ return
818
+ elif re.search(r"^\s*#\s*ignore-order\s*$", line_before_docstring):
819
+ ignore_order = True
820
+
821
+ # Read the signature
822
+ signature = inspect.signature(obj).parameters
823
+
824
+ obj_doc_lines = obj.__doc__.split("\n")
825
+ # Get to the line where we start documenting arguments
826
+ idx = 0
827
+ while idx < len(obj_doc_lines) and _re_args.search(obj_doc_lines[idx]) is None:
828
+ idx += 1
829
+
830
+ if idx == len(obj_doc_lines):
831
+ # Nothing to do, no parameters are documented.
832
+ return
833
+
834
+ if "kwargs" in signature and signature["kwargs"].annotation != inspect._empty:
835
+ # Inspecting signature with typed kwargs is not supported yet.
836
+ return
837
+
838
+ indent = find_indent(obj_doc_lines[idx])
839
+ arguments = {}
840
+ current_arg = None
841
+ idx += 1
842
+ start_idx = idx
843
+ # Keep going until the arg section is finished (nonempty line at the same indent level) or the end of the docstring.
844
+ while idx < len(obj_doc_lines) and (
845
+ len(obj_doc_lines[idx].strip()) == 0 or find_indent(obj_doc_lines[idx]) > indent
846
+ ):
847
+ if find_indent(obj_doc_lines[idx]) == indent + 4:
848
+ # New argument -> let's generate the proper doc for it
849
+ re_search_arg = _re_parse_arg.search(obj_doc_lines[idx])
850
+ if re_search_arg is not None:
851
+ _, name, description = re_search_arg.groups()
852
+ current_arg = name
853
+ if name in signature:
854
+ default = signature[name].default
855
+ if signature[name].kind is inspect._ParameterKind.VAR_KEYWORD:
856
+ default = None
857
+ new_description = replace_default_in_arg_description(description, default)
858
+ else:
859
+ new_description = description
860
+ init_doc = _re_parse_arg.sub(rf"\1\2 ({new_description}):", obj_doc_lines[idx])
861
+ arguments[current_arg] = [init_doc]
862
+ elif current_arg is not None:
863
+ arguments[current_arg].append(obj_doc_lines[idx])
864
+
865
+ idx += 1
866
+
867
+ # We went too far by one (perhaps more if there are a lot of new lines)
868
+ idx -= 1
869
+ if current_arg:
870
+ while len(obj_doc_lines[idx].strip()) == 0:
871
+ arguments[current_arg] = arguments[current_arg][:-1]
872
+ idx -= 1
873
+ # And we went too far by one again.
874
+ idx += 1
875
+
876
+ old_doc_arg = "\n".join(obj_doc_lines[start_idx:idx])
877
+
878
+ old_arguments = list(arguments.keys())
879
+ arguments = {name: "\n".join(doc) for name, doc in arguments.items()}
880
+ # Add missing arguments with a template
881
+ for name in set(signature.keys()) - set(arguments.keys()):
882
+ arg = signature[name]
883
+ # We ignore private arguments or *args/**kwargs (unless they are documented by the user)
884
+ if name.startswith("_") or arg.kind in [
885
+ inspect._ParameterKind.VAR_KEYWORD,
886
+ inspect._ParameterKind.VAR_POSITIONAL,
887
+ ]:
888
+ arguments[name] = ""
889
+ else:
890
+ arg_desc = get_default_description(arg)
891
+ arguments[name] = " " * (indent + 4) + f"{name} ({arg_desc}): <fill_docstring>"
892
+
893
+ # Arguments are sorted by the order in the signature unless a special comment is put.
894
+ if ignore_order:
895
+ new_param_docs = [arguments[name] for name in old_arguments if name in signature]
896
+ missing = set(signature.keys()) - set(old_arguments)
897
+ new_param_docs.extend([arguments[name] for name in missing if len(arguments[name]) > 0])
898
+ else:
899
+ new_param_docs = [arguments[name] for name in signature.keys() if len(arguments[name]) > 0]
900
+ new_doc_arg = "\n".join(new_param_docs)
901
+
902
+ return old_doc_arg, new_doc_arg
903
+
904
+
905
+ def fix_docstring(obj: Any, old_doc_args: str, new_doc_args: str):
906
+ """
907
+ Fixes the docstring of an object by replacing its arguments documentation by the one matched with the signature.
908
+
909
+ Args:
910
+ obj (`Any`):
911
+ The object whose dostring we are fixing.
912
+ old_doc_args (`str`):
913
+ The current documentation of the parameters of `obj` in the docstring (as returned by
914
+ `match_docstring_with_signature`).
915
+ new_doc_args (`str`):
916
+ The documentation of the parameters of `obj` matched with its signature (as returned by
917
+ `match_docstring_with_signature`).
918
+ """
919
+ # Read the docstring in the source code and make sure we have the right part of the docstring
920
+ source, line_number = inspect.getsourcelines(obj)
921
+
922
+ # Get to the line where we start documenting arguments
923
+ idx = 0
924
+ while idx < len(source) and _re_args.search(source[idx]) is None:
925
+ idx += 1
926
+
927
+ if idx == len(source):
928
+ # Args are not defined in the docstring of this object
929
+ return
930
+
931
+ # Get to the line where we stop documenting arguments
932
+ indent = find_indent(source[idx])
933
+ idx += 1
934
+ start_idx = idx
935
+ while idx < len(source) and (len(source[idx].strip()) == 0 or find_indent(source[idx]) > indent):
936
+ idx += 1
937
+
938
+ idx -= 1
939
+ while len(source[idx].strip()) == 0:
940
+ idx -= 1
941
+ idx += 1
942
+
943
+ if "".join(source[start_idx:idx])[:-1] != old_doc_args:
944
+ # Args are not fully defined in the docstring of this object
945
+ return
946
+
947
+ obj_file = find_source_file(obj)
948
+ with open(obj_file, "r", encoding="utf-8") as f:
949
+ content = f.read()
950
+
951
+ # Replace content
952
+ lines = content.split("\n")
953
+ lines = lines[: line_number + start_idx - 1] + [new_doc_args] + lines[line_number + idx - 1 :]
954
+
955
+ print(f"Fixing the docstring of {obj.__name__} in {obj_file}.")
956
+ with open(obj_file, "w", encoding="utf-8") as f:
957
+ f.write("\n".join(lines))
958
+
959
+
960
+ def check_docstrings(overwrite: bool = False, check_all: bool = False):
961
+ """
962
+ Check docstrings of all public objects that are callables and are documented. By default, only checks the diff.
963
+
964
+ Args:
965
+ overwrite (`bool`, *optional*, defaults to `False`):
966
+ Whether to fix inconsistencies or not.
967
+ check_all (`bool`, *optional*, defaults to `False`):
968
+ Whether to check all files.
969
+ """
970
+ module_diff_files = None
971
+ if not check_all:
972
+ module_diff_files = set()
973
+ repo = Repo(PATH_TO_REPO)
974
+ # Diff from index to unstaged files
975
+ for modified_file_diff in repo.index.diff(None):
976
+ if modified_file_diff.a_path.startswith("src/transformers"):
977
+ module_diff_files.add(modified_file_diff.a_path)
978
+ # Diff from index to `main`
979
+ for modified_file_diff in repo.index.diff(repo.refs.main.commit):
980
+ if modified_file_diff.a_path.startswith("src/transformers"):
981
+ module_diff_files.add(modified_file_diff.a_path)
982
+ # quick escape route: if there are no module files in the diff, skip this check
983
+ if len(module_diff_files) == 0:
984
+ return
985
+ print(" Checking docstrings in the following files:" + "\n - " + "\n - ".join(module_diff_files))
986
+
987
+ failures = []
988
+ hard_failures = []
989
+ to_clean = []
990
+ for name in dir(transformers):
991
+ # Skip objects that are private or not documented.
992
+ if name.startswith("_") or ignore_undocumented(name) or name in OBJECTS_TO_IGNORE:
993
+ continue
994
+
995
+ obj = getattr(transformers, name)
996
+ if not callable(obj) or not isinstance(obj, type) or getattr(obj, "__doc__", None) is None:
997
+ continue
998
+
999
+ # If we are checking against the diff, we skip objects that are not part of the diff.
1000
+ if module_diff_files is not None:
1001
+ object_file = find_source_file(getattr(transformers, name))
1002
+ object_file_relative_path = "src/" + str(object_file).split("/src/")[1]
1003
+ if object_file_relative_path not in module_diff_files:
1004
+ continue
1005
+
1006
+ # Check docstring
1007
+ try:
1008
+ result = match_docstring_with_signature(obj)
1009
+ if result is not None:
1010
+ old_doc, new_doc = result
1011
+ else:
1012
+ old_doc, new_doc = None, None
1013
+ except Exception as e:
1014
+ print(e)
1015
+ hard_failures.append(name)
1016
+ continue
1017
+ if old_doc != new_doc:
1018
+ if overwrite:
1019
+ fix_docstring(obj, old_doc, new_doc)
1020
+ else:
1021
+ failures.append(name)
1022
+ elif not overwrite and new_doc is not None and ("<fill_type>" in new_doc or "<fill_docstring>" in new_doc):
1023
+ to_clean.append(name)
1024
+
1025
+ # Deal with errors
1026
+ error_message = ""
1027
+ if len(hard_failures) > 0:
1028
+ error_message += (
1029
+ "The argument part of the docstrings of the following objects could not be processed, check they are "
1030
+ "properly formatted."
1031
+ )
1032
+ error_message += "\n" + "\n".join([f"- {name}" for name in hard_failures])
1033
+ if len(failures) > 0:
1034
+ error_message += (
1035
+ "The following objects docstrings do not match their signature. Run `make fix-copies` to fix this. "
1036
+ "In some cases, this error may be raised incorrectly by the docstring checker. If you think this is the "
1037
+ "case, you can manually check the docstrings and then add the object name to `OBJECTS_TO_IGNORE` in "
1038
+ "`utils/check_docstrings.py`."
1039
+ )
1040
+ error_message += "\n" + "\n".join([f"- {name}" for name in failures])
1041
+ if len(to_clean) > 0:
1042
+ error_message += (
1043
+ "The following objects docstrings contain templates you need to fix: search for `<fill_type>` or "
1044
+ "`<fill_docstring>`."
1045
+ )
1046
+ error_message += "\n" + "\n".join([f"- {name}" for name in to_clean])
1047
+
1048
+ if len(error_message) > 0:
1049
+ error_message = "There was at least one problem when checking docstrings of public objects.\n" + error_message
1050
+ raise ValueError(error_message)
1051
+
1052
+
1053
+ if __name__ == "__main__":
1054
+ parser = argparse.ArgumentParser()
1055
+ parser.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
1056
+ parser.add_argument(
1057
+ "--check_all", action="store_true", help="Whether to check all files. By default, only checks the diff"
1058
+ )
1059
+ args = parser.parse_args()
1060
+
1061
+ check_docstrings(overwrite=args.fix_and_overwrite, check_all=args.check_all)