| | import os |
| | import unittest |
| | from collections.abc import Callable |
| | from pathlib import Path |
| |
|
| | import pytest |
| |
|
| | from transformers.utils.import_utils import ( |
| | Backend, |
| | VersionComparison, |
| | define_import_structure, |
| | spread_import_structure, |
| | ) |
| |
|
| |
|
# Directory of fixture modules that sits next to this test file.
import_structures = Path(__file__).parent.joinpath("import_structures")
| |
|
| |
|
def _split_names(fragment):
    """Split a comma-separated fragment of a list literal into bare names, dropping empty entries."""
    candidates = (piece.strip("\"' \t") for piece in fragment.split(","))
    return [name for name in candidates if name]


def fetch__all__(file_content):
    """
    Returns the content of the __all__ variable in the file content.
    Returns None if not defined, otherwise returns a list of strings.

    Both the single-line form (`__all__ = ["A", "B"]`) and the multi-line form (one or several names per
    line, closed by a `]`) are supported. Empty entries (empty list, trailing commas, blank lines) and
    `#` comments inside the multi-line form are ignored.

    Only a line that starts with `__all__ = ` at column 0 is considered a definition. If a multi-line
    definition is never closed, the search carries on with the following lines and None is returned when
    nothing else matches.
    """
    lines = file_content.split("\n")
    for line_index, line in enumerate(lines):
        if not line.startswith("__all__ = "):
            continue

        value = line.split("=", 1)[1]

        # __all__ is defined on a single line
        if line.endswith("]"):
            return _split_names(value.strip(" []"))

        # __all__ is defined on multiple lines: names may already follow the opening bracket
        _all = _split_names(value.strip(" ["))
        for continuation in lines[line_index + 1 :]:
            # Names cannot contain '#', so everything after it is a comment.
            stripped = continuation.split("#", 1)[0].strip()
            if stripped.endswith("]"):
                # The closing bracket may share its line with the last name(s).
                _all.extend(_split_names(stripped[:-1]))
                return _all
            _all.extend(_split_names(stripped))

    return None
| |
|
| |
|
class TestImportStructures(unittest.TestCase):
    """Tests for `define_import_structure` and `spread_import_structure`."""

    # Evaluated once, when the class body is executed.
    base_transformers_path = Path(__file__).parent.parent.parent
    models_path = base_transformers_path / "src" / "transformers" / "models"
    # NOTE(review): not read by the tests visible in this class; kept as-is.
    models_import_structure = spread_import_structure(define_import_structure(models_path))

    def test_definition(self):
        """
        Checks that `define_import_structure` run on the fixture directory returns exactly the expected
        mapping of backend frozenset -> module name -> set of exported objects.
        """
        import_structure = define_import_structure(import_structures)
        valid_frozensets: dict[frozenset[str], dict[str, set[str]]] = {
            frozenset(): {
                "import_structure_raw_register": {"A0", "A4", "a0"},
                "import_structure_register_with_comments": {"B0", "b0"},
            },
            frozenset({"random_item_that_should_not_exist"}): {"failing_export": {"A0"}},
            frozenset({"torch"}): {
                "import_structure_raw_register": {"A1", "A2", "A3", "a1", "a2", "a3"},
                "import_structure_register_with_duplicates": {"C0", "C1", "C2", "C3", "c0", "c1", "c2", "c3"},
                "import_structure_register_with_comments": {"B1", "B2", "B3", "b1", "b2", "b3"},
            },
            frozenset({"torch>=2.5"}): {"import_structure_raw_register_with_versions": {"D0", "d0"}},
            frozenset({"torch>2.5"}): {"import_structure_raw_register_with_versions": {"D1", "d1"}},
            frozenset({"torch<=2.5"}): {"import_structure_raw_register_with_versions": {"D2", "d2"}},
            frozenset({"torch<2.5"}): {"import_structure_raw_register_with_versions": {"D3", "d3"}},
            frozenset({"torch==2.5"}): {"import_structure_raw_register_with_versions": {"D4", "d4"}},
            frozenset({"torch!=2.5"}): {"import_structure_raw_register_with_versions": {"D5", "d5"}},
            frozenset({"torch>=2.5", "accelerate<0.20"}): {
                "import_structure_raw_register_with_versions": {"D6", "d6"}
            },
        }

        # Same number of backend groups, then group-by-group equality of modules and objects.
        self.assertEqual(len(import_structure), len(valid_frozensets))
        for _frozenset, expected_modules in valid_frozensets.items():
            self.assertIn(_frozenset, import_structure)
            self.assertListEqual(sorted(import_structure[_frozenset].keys()), sorted(expected_modules.keys()))
            for module, objects in expected_modules.items():
                self.assertIn(module, import_structure[_frozenset])
                self.assertSetEqual(objects, import_structure[_frozenset][module])

    def test_transformers_specific_model_import(self):
        """
        This test ensures that there is equivalence between what is written down in __all__ and what is
        written down with register().

        It doesn't test the backends attributed to register().
        """
        # Sorted so that subtests always run in the same order.
        for architecture_dir in sorted(self.models_path.iterdir()):
            architecture = architecture_dir.name
            # Only model directories are checked: plain files, private entries and `deprecated` are skipped.
            if architecture_dir.is_file() or architecture.startswith("_") or architecture == "deprecated":
                continue

            with self.subTest(f"Testing arch {architecture}"):
                import_structure = define_import_structure(architecture_dir)

                # Merge the objects of every backend group under their module name.
                backend_agnostic_import_structure = {}
                for module_object_mapping in import_structure.values():
                    for module, objects in module_object_mapping.items():
                        backend_agnostic_import_structure.setdefault(module, []).extend(objects)

                for module, objects in backend_agnostic_import_structure.items():
                    module_path = architecture_dir / f"{module}.py"
                    # Explicit encoding so that reading source files does not depend on the locale.
                    with open(module_path, encoding="utf-8") as f:
                        _all = fetch__all__(f.read())

                    if _all is None:
                        raise ValueError(f"{module} doesn't have __all__ defined.")

                    error_message = (
                        f"{module_path} doesn't seem to be defined correctly:\n"
                        f"Defined in __all__: {sorted(_all)}\nDefined with register: {sorted(objects)}"
                    )
                    self.assertListEqual(sorted(objects), sorted(_all), msg=error_message)

    def test_import_spread(self):
        """
        This test is specifically designed to test that varying levels of depth across import structures are
        respected.

        In this instance, frozensets are at respective depths of 1, 2 and 3, for example:
        - models.{frozensets}
        - models.albert.{frozensets}
        - models.deprecated.transfo_xl.{frozensets}
        """
        initial_import_structure = {
            frozenset(): {"dummy_non_model": {"DummyObject"}},
            "models": {
                frozenset(): {"dummy_config": {"DummyConfig"}},
                "albert": {
                    frozenset(): {"configuration_albert": {"AlbertConfig"}},
                    frozenset({"torch"}): {
                        "modeling_albert": {
                            "AlbertForMaskedLM",
                        }
                    },
                },
                "llama": {
                    frozenset(): {"configuration_llama": {"LlamaConfig"}},
                    frozenset({"torch"}): {
                        "modeling_llama": {
                            "LlamaForCausalLM",
                        }
                    },
                },
                "deprecated": {
                    "transfo_xl": {
                        frozenset({"torch"}): {
                            "modeling_transfo_xl": {
                                "TransfoXLModel",
                            }
                        },
                        frozenset(): {
                            "configuration_transfo_xl": {"TransfoXLConfig"},
                            "tokenization_transfo_xl": {"TransfoXLCorpus", "TransfoXLTokenizer"},
                        },
                    },
                    "deta": {
                        frozenset({"torch"}): {
                            "modeling_deta": {"DetaForObjectDetection", "DetaModel", "DetaPreTrainedModel"}
                        },
                        frozenset(): {"configuration_deta": {"DetaConfig"}},
                        frozenset({"vision"}): {"image_processing_deta": {"DetaImageProcessor"}},
                    },
                },
            },
        }

        # Expected result: one flat mapping per backend frozenset, keyed by dotted module path.
        ground_truth_spread_import_structure = {
            frozenset(): {
                "dummy_non_model": {"DummyObject"},
                "models.dummy_config": {"DummyConfig"},
                "models.albert.configuration_albert": {"AlbertConfig"},
                "models.llama.configuration_llama": {"LlamaConfig"},
                "models.deprecated.transfo_xl.configuration_transfo_xl": {"TransfoXLConfig"},
                "models.deprecated.transfo_xl.tokenization_transfo_xl": {"TransfoXLCorpus", "TransfoXLTokenizer"},
                "models.deprecated.deta.configuration_deta": {"DetaConfig"},
            },
            frozenset({"torch"}): {
                "models.albert.modeling_albert": {"AlbertForMaskedLM"},
                "models.llama.modeling_llama": {"LlamaForCausalLM"},
                "models.deprecated.transfo_xl.modeling_transfo_xl": {"TransfoXLModel"},
                "models.deprecated.deta.modeling_deta": {"DetaForObjectDetection", "DetaModel", "DetaPreTrainedModel"},
            },
            frozenset({"vision"}): {"models.deprecated.deta.image_processing_deta": {"DetaImageProcessor"}},
        }

        newly_spread_import_structure = spread_import_structure(initial_import_structure)

        self.assertEqual(ground_truth_spread_import_structure, newly_spread_import_structure)
| |
|
| |
|
@pytest.mark.parametrize(
    ["backend", "package_name", "version_comparison", "version"],
    [
        # NOTE(review): the trailing space in the first requirement string is part of the input;
        # presumably it exercises whitespace handling in Backend - confirm.
        (Backend("torch>=2.5 "), "torch", VersionComparison.GREATER_THAN_OR_EQUAL.value, "2.5"),
        (Backend("torchvision==0.19.1"), "torchvision", VersionComparison.EQUAL.value, "0.19.1"),
    ],
)
def test_backend_specification(backend: Backend, package_name: str, version_comparison: Callable, version: str):
    """
    Checks that a `Backend` built from a requirement string exposes the expected package name,
    version comparison and version.
    """
    assert backend.package_name == package_name

    # The backend stores the comparison in a form that `VersionComparison.from_string` accepts.
    resolved_comparison = VersionComparison.from_string(backend.version_comparison)
    assert resolved_comparison == version_comparison

    assert backend.version == version
| |
|