| import re |
|
|
| from transformers.pipelines import SUPPORTED_TASKS, Pipeline |
|
|
|
|
# Banner written at the top of the generated section. It warns against manual edits and
# carries the `typing` imports that the generated `@overload` stubs need.
HEADER = """
# fmt: off
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# The part of the file below was automatically generated from the code.
# Do NOT edit this part of the file manually as any edits will be overwritten by the generation
# of the file. If any change should be done, please apply the changes to the `pipeline` function
# below and run `python utils/check_pipeline_typing.py --fix_and_overwrite` to update the file.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨

from typing import Literal, overload


"""


# Banner written at the bottom of the generated section; re-enables the formatter.
FOOTER = """
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# The part of the file above was automatically generated from the code.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# fmt: on
"""


# Exact text of the `task` parameter inside the `pipeline` signature. Each generated
# overload replaces it with a `task: Literal[...]` annotation.
TASK_PATTERN = "task: Optional[str] = None"
|
|
|
|
def main(pipeline_file_path: str, fix_and_overwrite: bool = False):
    """Check, and optionally regenerate, the typed `pipeline` overloads of a file.

    The text found between the `# <generated-code>` and `# </generated-code>` markers
    of `pipeline_file_path` is compared against freshly generated `@overload` stubs:
    one returning `Pipeline` for `task: Literal[None]`, followed by one per entry of
    `SUPPORTED_TASKS` sorted by task name, wrapped in `HEADER` and `FOOTER`.

    Args:
        pipeline_file_path: Path of the Python file that holds the `pipeline`
            function and the generated-code markers.
        fix_and_overwrite: When true, rewrite the file with the regenerated section
            instead of raising on a mismatch.

    Raises:
        ValueError: If the markers or the `def pipeline(...) -> Pipeline:` signature
            cannot be found, if `TASK_PATTERN` is missing from that signature, or if
            the generated section is out of date and `fix_and_overwrite` is false.
    """
    # The generated banner contains non-ASCII characters, so the platform default
    # encoding must not be relied upon.
    with open(pipeline_file_path, "r", encoding="utf-8") as file:
        content = file.read()

    # Locate the current generated section (everything between the two markers).
    generated_match = re.search(r"# <generated-code>(.*)# </generated-code>", content, re.DOTALL)
    if generated_match is None:
        raise ValueError(
            f"Can't find the `# <generated-code>` / `# </generated-code>` markers in {pipeline_file_path}."
        )
    section_start, section_end = generated_match.span(1)
    current_generated_code = generated_match.group(1)

    # Cut out exactly the matched span. A textual `str.replace` would also delete any
    # other occurrence of the same text (e.g. every newline if the section is "\n").
    content_without_generated_code = content[:section_start] + content[section_end:]

    # Grab the parameter list of the real `pipeline` function. The match is non-greedy
    # so that it stops at the first ` -> Pipeline:` rather than the last one in the file.
    signature_match = re.search(r"def pipeline(.*?) -> Pipeline:", content_without_generated_code, re.DOTALL)
    if signature_match is None:
        raise ValueError(f"Can't find `def pipeline(...) -> Pipeline:` in {pipeline_file_path}.")

    # Collapse the multi-line signature onto a single line, whatever the indent width.
    pipeline_signature = signature_match.group(1)
    pipeline_signature = re.sub(r"\(\n[ \t]+", "(", pipeline_signature)
    pipeline_signature = re.sub(r",\n[ \t]+", ", ", pipeline_signature)
    pipeline_signature = pipeline_signature.replace(",\n)", ")")

    # Every overload is produced by substituting TASK_PATTERN, so check it once up front.
    if TASK_PATTERN not in pipeline_signature:
        raise ValueError(f"Can't find `{TASK_PATTERN}` in pipeline signature: {pipeline_signature}")

    # Generic fallback first (`task: Literal[None]` -> Pipeline), then one entry per
    # supported task, ordered by the quoted task name.
    task_entries = sorted(
        ((f'"{task}"', task_info["impl"]) for task, task_info in SUPPORTED_TASKS.items()),
        key=lambda entry: entry[0],
    )
    pipelines = [(None, Pipeline), *task_entries]

    overloads = []
    for task, pipeline_class in pipelines:
        # `impl` may be given either as a class or directly as a class name.
        pipeline_type = pipeline_class if isinstance(pipeline_class, str) else pipeline_class.__name__
        task_signature = pipeline_signature.replace(TASK_PATTERN, f"task: Literal[{task}]")
        overloads.append(f"@overload\ndef pipeline{task_signature} -> {pipeline_type}: ...\n")

    # Normalise to exactly one trailing newline before the closing marker.
    new_generated_code = (HEADER + "".join(overloads) + FOOTER).rstrip("\n") + "\n"

    if new_generated_code == current_generated_code:
        return

    if not fix_and_overwrite:
        message = (
            f"Found inconsistencies in {pipeline_file_path}. "
            "Run `python utils/check_pipeline_typing.py --fix_and_overwrite` to fix them."
        )
        raise ValueError(message)

    print(f"Updating {pipeline_file_path}...")
    updated_content = content[:section_start] + new_generated_code + content[section_end:]
    with open(pipeline_file_path, "w", encoding="utf-8") as file:
        file.write(updated_content)
|
|
|
|
if __name__ == "__main__":
    import argparse

    # Command-line front end: one switch to enable rewriting, one option for the target file.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--fix_and_overwrite",
        help="Whether to fix inconsistencies.",
        action="store_true",
    )
    cli.add_argument(
        "--pipeline_file_path",
        default="src/transformers/pipelines/__init__.py",
        type=str,
        help="Path to the pipeline file.",
    )
    options = cli.parse_args()

    main(
        pipeline_file_path=options.pipeline_file_path,
        fix_and_overwrite=options.fix_and_overwrite,
    )
|
|