| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import json |
| import os |
| import shutil |
| import warnings |
| from argparse import ArgumentParser, Namespace |
| from pathlib import Path |
| from typing import List |
|
|
| from ..utils import logging |
| from . import BaseTransformersCLICommand |
|
|
|
|
# cookiecutter is optional (installed via the `modelcreation` extra). Record whether it can be
# imported so that `AddNewModelCommand.run` can raise an explanatory ImportError at use time
# instead of this module failing to import.
try:
    from cookiecutter.main import cookiecutter
except ImportError:
    _has_cookiecutter = False
else:
    _has_cookiecutter = True
|
|
# Module-level logger, obtained through the package's own `..utils.logging` helper.
logger = logging.get_logger(__name__)
|
|
|
|
def add_new_model_command_factory(args: Namespace):
    """Build an ``AddNewModelCommand`` from the parsed ``add-new-model`` CLI arguments.

    Registered as the subcommand's ``func`` default; forwards ``--testing``,
    ``--testing_file`` and ``--path`` to the command constructor.
    """
    testing = args.testing
    testing_file = args.testing_file
    cookiecutter_path = args.path
    return AddNewModelCommand(testing, testing_file, path=cookiecutter_path)
|
|
|
|
class AddNewModelCommand(BaseTransformersCLICommand):
    """Deprecated ``transformers-cli add-new-model`` command.

    Renders the ``templates/adding_a_new_model`` cookiecutter template into a
    ``cookiecutter-template-*`` directory in the current working directory. It then moves the
    generated configuration, modeling, tokenization, test and doc files into the repository
    tree. Finally it splices the snippets described in ``to_replace_<model>.py`` into
    existing files.
    """

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """Register the ``add-new-model`` subcommand and its options on ``parser``."""
        add_new_model_parser = parser.add_parser("add-new-model")
        add_new_model_parser.add_argument("--testing", action="store_true", help="If in testing mode.")
        add_new_model_parser.add_argument("--testing_file", type=str, help="Configuration file on which to run.")
        add_new_model_parser.add_argument(
            "--path", type=str, help="Path to cookiecutter. Should only be used for testing purposes."
        )
        add_new_model_parser.set_defaults(func=add_new_model_command_factory)

    def __init__(self, testing: bool, testing_file: str, path=None, *args):
        """
        Args:
            testing: When True, cookiecutter runs with ``no_input=True`` using the context loaded
                from ``testing_file``, and ``# Copied from transformers.`` lines are kept.
            testing_file: Path to a JSON file holding the cookiecutter context (read only when
                ``testing`` is True).
            path: Optional template path. When given, its grandparent directory is used as the
                repository root, and (in testing mode) it is passed to cookiecutter as the template.
            *args: Accepted and ignored.
        """
        self._testing = testing
        self._testing_file = testing_file
        self._path = path

    def run(self):
        """Generate the new-model files and move them into the repository.

        Raises:
            ImportError: If cookiecutter is not installed.
            ValueError: If a ``cookiecutter-template-*`` directory already exists in the working
                directory, if cookiecutter did not create one, or if a snippet's anchor line is
                missing from its target file.
        """
        warnings.warn(
            "The command `transformers-cli add-new-model` is deprecated and will be removed in v5 of Transformers. "
            "It is not actively maintained anymore, so might give a result that won't pass all tests and quality "
            "checks, you should use `transformers-cli add-new-model-like` instead."
        )
        if not _has_cookiecutter:
            raise ImportError(
                "Model creation dependencies are required to use the `add_new_model` command. Install them by running "
                "the following at the root of your `transformers` clone:\n\n\t$ pip install -e .[modelcreation]\n"
            )

        # The generated output is located afterwards by this prefix, so any leftover directory with
        # the same prefix would make that lookup ambiguous.
        template_prefix = "cookiecutter-template-"
        directories = [directory for directory in os.listdir() if directory.startswith(template_prefix)]
        if len(directories) > 0:
            raise ValueError(
                "Several directories starting with `cookiecutter-template-` in current working directory. "
                "Please clean your directory by removing all folders starting with `cookiecutter-template-` or "
                "change your working directory."
            )

        path_to_transformer_root = (
            Path(__file__).parent.parent.parent.parent if self._path is None else Path(self._path).parent.parent
        )
        path_to_cookiecutter = path_to_transformer_root / "templates" / "adding_a_new_model"

        # Execute cookiecutter: interactively, or non-interactively from the testing configuration.
        if not self._testing:
            cookiecutter(str(path_to_cookiecutter))
        else:
            with open(self._testing_file, "r", encoding="utf-8") as configuration_file:
                testing_configuration = json.load(configuration_file)

            cookiecutter(
                str(path_to_cookiecutter if self._path is None else self._path),
                no_input=True,
                extra_context=testing_configuration,
            )

        generated = [directory for directory in os.listdir() if directory.startswith(template_prefix)]
        if not generated:
            raise ValueError(
                f"cookiecutter did not create a `{template_prefix}*` directory in the current working directory."
            )
        directory = generated[0]

        # Retrieve the configuration cookiecutter rendered alongside the files, then discard it.
        with open(f"{directory}/configuration.json", "r", encoding="utf-8") as configuration_file:
            configuration = json.load(configuration_file)

        lowercase_model_name = configuration["lowercase_modelname"]
        generate_tensorflow_pytorch_and_flax = configuration["generate_tensorflow_pytorch_and_flax"]
        os.remove(f"{directory}/configuration.json")

        output_pytorch = "PyTorch" in generate_tensorflow_pytorch_and_flax
        output_tensorflow = "TensorFlow" in generate_tensorflow_pytorch_and_flax
        output_flax = "Flax" in generate_tensorflow_pytorch_and_flax

        model_dir = f"{path_to_transformer_root}/src/transformers/models/{lowercase_model_name}"
        tests_dir = f"{path_to_transformer_root}/tests/models/{lowercase_model_name}"
        os.makedirs(model_dir, exist_ok=True)
        os.makedirs(tests_dir, exist_ok=True)

        # Create an empty __init__.py in the model's tests directory.
        with open(f"{tests_dir}/__init__.py", "w", encoding="utf-8"):
            pass

        shutil.move(
            f"{directory}/__init__.py",
            f"{model_dir}/__init__.py",
        )
        shutil.move(
            f"{directory}/configuration_{lowercase_model_name}.py",
            f"{model_dir}/configuration_{lowercase_model_name}.py",
        )

        def remove_copy_lines(path):
            """Rewrite ``path`` in place, dropping every line containing ``# Copied from transformers.``."""
            with open(path, "r", encoding="utf-8") as f:
                lines = f.readlines()
            with open(path, "w", encoding="utf-8") as f:
                f.writelines(line for line in lines if "# Copied from transformers." not in line)

        # For each framework (in PyTorch, TensorFlow, Flax order): either move the generated
        # modeling + test files into place, or delete them when that framework was not requested.
        frameworks = (
            ("", output_pytorch),
            ("tf_", output_tensorflow),
            ("flax_", output_flax),
        )
        for framework_prefix, enabled in frameworks:
            modeling_file = f"modeling_{framework_prefix}{lowercase_model_name}.py"
            test_file = f"test_{modeling_file}"
            if enabled:
                if not self._testing:
                    remove_copy_lines(f"{directory}/{modeling_file}")

                shutil.move(f"{directory}/{modeling_file}", f"{model_dir}/{modeling_file}")
                shutil.move(f"{directory}/{test_file}", f"{tests_dir}/{test_file}")
            else:
                os.remove(f"{directory}/{modeling_file}")
                os.remove(f"{directory}/{test_file}")

        shutil.move(
            f"{directory}/{lowercase_model_name}.md",
            f"{path_to_transformer_root}/docs/source/en/model_doc/{lowercase_model_name}.md",
        )

        shutil.move(
            f"{directory}/tokenization_{lowercase_model_name}.py",
            f"{model_dir}/tokenization_{lowercase_model_name}.py",
        )

        shutil.move(
            f"{directory}/tokenization_fast_{lowercase_model_name}.py",
            f"{model_dir}/tokenization_{lowercase_model_name}_fast.py",
        )

        from os import fdopen, remove
        from shutil import copymode, move
        from tempfile import mkstemp

        def replace(original_file: str, line_to_copy_below: str, lines_to_copy: List[str]):
            """Insert ``lines_to_copy`` after every line of ``original_file`` containing ``line_to_copy_below``.

            The result is written to a temporary file, which then replaces the original (keeping its
            permission bits via ``copymode``).

            Raises:
                ValueError: If ``line_to_copy_below`` never occurs. The original file is left
                    untouched and the temporary file is deleted.
            """
            fh, abs_path = mkstemp()
            line_found = False
            try:
                with fdopen(fh, "w", encoding="utf-8") as new_file:
                    with open(original_file, encoding="utf-8") as old_file:
                        for line in old_file:
                            new_file.write(line)
                            if line_to_copy_below in line:
                                line_found = True
                                new_file.writelines(lines_to_copy)

                if not line_found:
                    raise ValueError(f"Line {line_to_copy_below} was not found in file.")
            except BaseException:
                # Do not leave the temporary file behind on failure.
                remove(abs_path)
                raise

            # Copy the file permissions from the old file to the new file.
            copymode(original_file, abs_path)
            # Swap the new file in place of the original.
            remove(original_file)
            move(abs_path, original_file)

        def skip_units(line):
            """Return True if ``line`` is tagged for a framework that was not requested."""
            return (
                ("generating PyTorch" in line and not output_pytorch)
                or ("generating TensorFlow" in line and not output_tensorflow)
                or ("generating Flax" in line and not output_flax)
            )

        def replace_in_files(path_to_datafile):
            """Apply every snippet described in ``path_to_datafile``, then delete that file.

            Marker lines (ignored when they contain ``##``):
              - ``# To replace in: "<file>"``: select the target file.
              - ``# Below: "<anchor>"``: select the anchor line.
              - ``# Replace with``: start a new snippet.
              - ``# End.``: insert the collected snippet, unless the file or snippet marker is
                tagged for a framework that was not requested.
            Other lines without ``##`` are collected as snippet content.
            """
            with open(path_to_datafile, encoding="utf-8") as datafile:
                lines_to_copy = []
                skip_file = False
                skip_snippet = False
                for line in datafile:
                    if "# To replace in: " in line and "##" not in line:
                        file_to_replace_in = line.split('"')[1]
                        skip_file = skip_units(line)
                    elif "# Below: " in line and "##" not in line:
                        line_to_copy_below = line.split('"')[1]
                        skip_snippet = skip_units(line)
                    elif "# End." in line and "##" not in line:
                        if not skip_file and not skip_snippet:
                            replace(file_to_replace_in, line_to_copy_below, lines_to_copy)

                        lines_to_copy = []
                    elif "# Replace with" in line and "##" not in line:
                        lines_to_copy = []
                    elif "##" not in line:
                        lines_to_copy.append(line)

            remove(path_to_datafile)

        replace_in_files(f"{directory}/to_replace_{lowercase_model_name}.py")
        os.rmdir(directory)
|
|