| | import ast |
| | import re |
| | from collections import defaultdict |
| |
|
| |
|
| | |
def topological_sort(dependencies: dict) -> list[list[str]]:
    """Given the dependencies graph, construct a sorted list of list of modular files.

    Args:
        dependencies (`dict`):
            Maps each modular file path (containing `modular_<model>.py`) to an iterable of dotted module paths
            it imports, e.g. `"transformers.models.llama.modeling_llama"`. The second-to-last component of each
            dotted path is taken as the name of the model being depended on.

    Returns:
        A list of levels. Each level is a list of modular file paths whose dependencies all belong to previous
        levels. Inside a level, files keep the order in which they appear in `dependencies`.

    Raises:
        ValueError: If the modular files depend on each other in a cycle, so that no valid order exists.

    Examples:

        The returned list of lists might be:
        [
            ["../modular_mistral.py", "../modular_gemma.py"],  # level 0
            ["../modular_llama4.py", "../modular_gemma2.py"],  # level 1
            ["../modular_glm4.py"],                            # level 2
        ]
        which means mistral and gemma do not depend on any other modular models, while llama4 and gemma2
        depend on the models in the first list, and glm4 depends on the models in the second and (optionally) in
        the first list.
    """

    def model_name(modular_file: str) -> str:
        # ".../modular_llama.py" -> "llama"
        return modular_file.rsplit("modular_", 1)[1].replace(".py", "")

    # Only the models being converted are nodes of the graph; imports of any other model are ignored below
    name_mapping = {model_name(modular_file): modular_file for modular_file in dependencies}

    # Graph from each model to the set of models (among the nodes) that must be handled before it
    graph = {}
    for modular_file, imported_modules in dependencies.items():
        node_name = model_name(modular_file)
        imported_models = {module.split(".")[-2] for module in imported_modules}
        # A model importing from its own folder is not a dependency on another modular file
        graph[node_name] = {dep for dep in imported_models if dep in name_mapping and dep != node_name}

    sorting_list = []
    while graph:
        # Nodes without any unresolved dependency form the next level (dict order keeps this deterministic)
        leaf_nodes = [node for node, deps in graph.items() if not deps]
        if not leaf_nodes:
            # Every remaining node waits on another remaining node: without this check the loop never ends
            unresolved = sorted(name_mapping[node] for node in graph)
            raise ValueError(f"Circular dependency detected between modular files: {unresolved}")
        sorting_list.append([name_mapping[node] for node in leaf_nodes])
        # Remove the leaves from the graph, and from the dependencies of the remaining nodes
        resolved = set(leaf_nodes)
        graph = {node: deps - resolved for node, deps in graph.items() if node not in resolved}

    return sorting_list
| |
|
| |
|
| | |
# Prefixes of the per-model file names (`<prefix>_<model>.py`) that count as model-specific files.
# `is_model_import` joins these into a regex alternation, so this must stay a sequence of plain strings.
ALL_FILE_TYPES = (
    # core model definition and its configuration
    "modeling",
    "configuration",
    # input pre-processing
    "tokenization",
    "processing",
    "image_processing",
    "video_processing",
    "feature_extraction",
)
| |
|
| |
|
def is_model_import(module: str | None) -> bool:
    """Tell whether the dotted path `module` points at a model-specific file.

    A path qualifies when it contains a `<model>.<file_type>_<suffix>` segment where `<file_type>` is one of
    `ALL_FILE_TYPES`, `<model>` occurs inside `<suffix>`, and `<model>` is not `"auto"`.

    Args:
        module (`str` or `None`):
            Dotted module path of an import statement, or `None` when the import names no module.

    Returns:
        `True` if the path looks like a model file import, `False` otherwise (including for `None`).
    """
    # Nothing to inspect when the import statement carries no module path
    if module is None:
        return False

    patterns = "|".join(ALL_FILE_TYPES)
    found = re.search(rf"(\w+)\.(?:{patterns})_(\w+)", module)
    if found is None:
        return False

    # The folder name must show up in the file suffix (e.g. `llama.modeling_llama`); `auto` is never a model
    folder_name, file_suffix = found.groups()
    return folder_name != "auto" and folder_name in file_suffix
| |
|
| |
|
def extract_model_imports_from_file(file_path):
    """Collect the model-specific `from ... import ...` modules used by the python file at `file_path`.

    Args:
        file_path: Path of the python file to read (decoded as UTF-8) and parse.

    Returns:
        The set of module paths, taken from every `from <module> import ...` statement anywhere in the file,
        for which `is_model_import` returns `True`.
    """
    with open(file_path, "r", encoding="utf-8") as handle:
        source = handle.read()
    parsed = ast.parse(source, filename=file_path)

    # `ast.walk` visits every node, so imports nested inside functions, classes or `if` blocks are found too
    return {
        statement.module
        for statement in ast.walk(parsed)
        if isinstance(statement, ast.ImportFrom) and is_model_import(statement.module)
    }
| |
|
| |
|
def find_priority_list(modular_files: list[str]) -> tuple[list[list[str]], dict[str, set]]:
    """
    Order a list of modular files topologically, so that modular models that DON'T depend on other modular models
    come first.

    Args:
        modular_files (`list[str]`):
            List of paths to the modular files.

    Returns:
        A tuple `(ordered_files, dependencies)`.

        `ordered_files` is a list of lists consisting of the models at each level of the dependency graph. For
        example, it might be:
        [
            ["../modular_mistral.py", "../modular_gemma.py"],  # level 0
            ["../modular_llama4.py", "../modular_gemma2.py"],  # level 1
            ["../modular_glm4.py"],                            # level 2
        ]
        which means mistral and gemma do not depend on any other modular models, while llama4 and gemma2 depend on
        the models in the first list, and glm4 depends on the models in the second and (optionally) in the first
        list.

        `dependencies` is a dictionary mapping each modular file to the set of model-specific modules it imports,
        as returned by `extract_model_imports_from_file`.
    """
    # One entry per modular file, holding the model imports found in that file
    dependencies = defaultdict(
        set, {modular_file: set(extract_model_imports_from_file(modular_file)) for modular_file in modular_files}
    )
    return topological_sort(dependencies), dependencies
| |
|