| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | import argparse |
| | import glob |
| | import importlib |
| | import multiprocessing as mp |
| | import os |
| | import re |
| | import subprocess |
| | from abc import ABC, abstractmethod |
| | from collections import Counter, defaultdict, deque |
| | from functools import partial |
| |
|
| | import libcst as cst |
| | from create_dependency_mapping import find_priority_list |
| | from libcst import ClassDef, CSTVisitor |
| | from libcst import matchers as m |
| | from libcst.metadata import MetadataWrapper, ParentNodeProvider, PositionProvider, ScopeProvider |
| | from modular_integrations import ( |
| | EXCLUDED_EXTERNAL_FILES, |
| | AbsoluteImportTransformer, |
| | RelativeImportTransformer, |
| | convert_relative_import_to_absolute, |
| | ) |
| |
|
| | from transformers import logging |
| | from transformers.models.auto.configuration_auto import CONFIG_MAPPING_NAMES |
| |
|
| |
|
# Module-level logger following the transformers logging convention.
logger = logging.get_logger(__name__)
| |
|
| |
|
# Banner prepended to every generated file. `{relative_path}` and `{short_name}` are filled in with
# the modular source file the generated file derives from.
AUTO_GENERATED_MESSAGE = """# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from {relative_path}.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# {short_name} file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
"""
| |
|
| |
|
def get_module_source_from_name(module_name: str) -> str:
    """Locate the file backing `module_name` and return its raw source code.

    Raises:
        ValueError: if the module cannot be found, or has no file on disk (e.g. a built-in).
    """
    module_spec = importlib.util.find_spec(module_name)
    if module_spec is None or module_spec.origin is None:
        raise ValueError(f"Cannot open file associated with {module_name} module.")
    with open(module_spec.origin, "r", encoding="utf-8") as source_file:
        return source_file.read()
| |
|
| |
|
def preserve_case_replace(text, patterns: dict, default_name: str):
    """Replace every key of `patterns` found in `text` by its value, matching case-insensitively.

    A match whose exact casing is a key of `patterns` uses that key's value; any other casing
    falls back to `default_name`. The lookbehind prevents replacing inside an identifier that
    continues a lowercase/digit run.
    """
    joined_keys = "|".join(re.escape(key) for key in patterns)
    # DOTALL lets the trailing `(.|$)` group also capture a newline following the match.
    matcher = re.compile(f"(?<![a-z0-9])({joined_keys})(.|$)", re.IGNORECASE | re.DOTALL)

    def substitute(match):
        found = match.group(1)
        following = match.group(2)
        replacement = patterns.get(found, default_name)
        # With only 2 patterns the cased and lowercase forms coincide; disambiguate an all-caps
        # match (when not followed by a letter) by upper-casing the lowercase replacement.
        if len(patterns) == 2 and found.isupper() and not following.isalpha():
            replacement = patterns[found.lower()].upper()
        return replacement + following

    return matcher.sub(substitute, text)
| |
|
| |
|
def get_cased_name(lowercase_name: str) -> str:
    """From a model name in lowercase in the format `my_model`, return the cased name in the format `MyModel`."""
    # Prefer the official mapping (which also knows dashed names such as `xlm-roberta`),
    # falling back to naive title-casing of each underscore-separated chunk.
    dashed_name = lowercase_name.replace("_", "-")
    for candidate in (lowercase_name, dashed_name):
        if candidate in CONFIG_MAPPING_NAMES:
            return CONFIG_MAPPING_NAMES[candidate].replace("Config", "")
    return "".join(chunk.title() for chunk in lowercase_name.split("_"))
| |
|
| |
|
def get_lowercase_name(cased_name: str) -> str:
    """From a model name in Camelcase in the format `MyModel`, return the lowercase name in the format `my_model`."""
    # Invert the official config mapping first (later duplicates win, as with the original dict build);
    # fall back to splitting on capital letters.
    reverse_config_mapping = {config: model for model, config in CONFIG_MAPPING_NAMES.items()}
    config_name = cased_name + "Config"
    if config_name in reverse_config_mapping:
        return reverse_config_mapping[config_name]
    return "_".join(chunk.lower() for chunk in re.findall(r"[A-Z][^A-Z]*", cased_name))
| |
|
| |
|
class ReplaceNameTransformer(m.MatcherDecoratableTransformer):
    """A transformer that replaces `old_name` with `new_name` in comments, string and any references.
    It should take into account name like `MyNewModel`, or `my_new_model`. Without using the AUTO_MAPPING.
    Supported renaming patterns:
        - llama -> my_new_model and my_new_model -> llama
        - Llama -> MyNewModel and MyNewModel -> Llama
        - LLAMA -> MY_NEW_MODEL and MY_NEW_MODEL -> LLAMA
        - LLaMa -> MyNewModel and MyNewModel -> Llama
    """

    def __init__(self, old_name: str, new_name: str, original_new_model_name: str = "", only_doc: bool = False):
        super().__init__()
        # Normalize dashed model names (e.g. `xlm-roberta`) to the underscore form used in identifiers.
        old_name = old_name.replace("-", "_")
        new_name = new_name.replace("-", "_")
        self.old_name = old_name
        self.new_name = new_name
        self.cased_new_name = get_cased_name(self.new_name)
        self.cased_old_name = get_cased_name(self.old_name)
        self.patterns = {
            old_name: new_name,
            old_name.upper(): new_name.upper(),
            # note: the cased name may coincide with one of the keys above (e.g. for names like "T5"),
            # in which case `patterns` only ends up with 2 entries
            self.cased_old_name: self.cased_new_name,
        }
        # Name of the new model as written in the modular file (may differ from `new_name` for renamed imports).
        self.original_new_model_name = original_new_model_name
        # If True, only strings/comments are renamed, not identifiers.
        self.only_doc = only_doc

    def _replace_name(self, original_node, updated_node):
        # Drop stale "# Copied from" markers entirely — they would no longer hold after renaming.
        if re.findall(r"# Copied from", updated_node.value):
            return cst.RemoveFromParent()
        update = preserve_case_replace(updated_node.value, self.patterns, self.cased_new_name)
        return updated_node.with_changes(value=update)

    @m.leave(m.SimpleString() | m.Comment())
    def replace_name(self, original_node, updated_node):
        # Strings and comments are always renamed, regardless of `only_doc`.
        return self._replace_name(original_node, updated_node)

    def leave_Name(self, original_node, updated_node):
        # Identifiers are only renamed when not restricted to doc-only replacement.
        if not self.only_doc:
            return self._replace_name(original_node, updated_node)
        return updated_node

    def leave_ImportFrom(self, original_node, updated_node):
        """
        The imports from other file types (configuration, processing etc) should use original model name.
        Also, no replaces on absolute imports (e.g. `from mamba_ssm import ...`)
        """
        if len(original_node.relative) == 0:
            return original_node
        if self.original_new_model_name != self.new_name and m.matches(updated_node.module, m.Name()):
            # Rewrite e.g. `configuration_<new_name>` back to `configuration_<original_new_model_name>`.
            # NOTE(review): the lambda parameter `m` shadows the `libcst.matchers` alias within the lambda body.
            patterns = "|".join(ALL_FILE_TYPES)
            regex = rf"({patterns})_{self.new_name}"
            new_source = re.sub(
                regex, lambda m: f"{m.group(1)}_{self.original_new_model_name}", updated_node.module.value
            )
            updated_node = updated_node.with_changes(module=updated_node.module.with_changes(value=new_source))
        return updated_node
| |
|
| |
|
# Matcher for a statement line that is exactly a triple-quoted docstring expression.
DOCSTRING_NODE = m.SimpleStatementLine(
    body=[
        m.Expr(
            value=m.SimpleString(
                # Match anything between triple quotes, including multi-line content.
                value=m.MatchIfTrue(lambda value: re.search(r"\"\"\"[\s\S]*\"\"\"", value) is not None)
            )
        )
    ]
)
| |
|
| |
|
def get_full_attribute_name(node: cst.Attribute | cst.Name) -> str | None:
    """Get the full name of an Attribute or Name node (e.g. `"nn.Module"` for an Attribute representing it). If the
    successive value of an Attribute are not Name nodes, return `None`."""
    if m.matches(node, m.Name()):
        return node.value
    if not m.matches(node, m.Attribute()):
        return None
    # Walk the attribute chain from right to left, collecting the dotted parts.
    parts = []
    current = node
    while m.matches(current, m.Attribute()):
        if not m.matches(current.attr, m.Name()):
            return None
        parts.append(current.attr.value)
        current = current.value
    # The leftmost element of the chain must itself be a plain Name.
    if not m.matches(current, m.Name()):
        return None
    parts.append(current.value)
    return ".".join(reversed(parts))
| |
|
| |
|
class ReplaceParentClassCallTransformer(cst.CSTTransformer):
    """
    This Transformer is used to replace all calls of the form `module.Class.func(...)` by a call of the form
    `super().func(...)`.
    """

    def __init__(self, new_bases: list[str]):
        # Full dotted names of the new base classes of the class being transformed.
        self.new_bases = new_bases

    def is_call_to_parent_class(self, node: cst.Call):
        """Check whether `node` corresponds to a call to a parent class function, such as `module.Parent.func_name(...)`"""
        return m.matches(node, m.Call(func=m.Attribute(value=m.Name() | m.Attribute())))

    def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call:
        """Replace a call of the form `module.Class.func(...)` by a call of the form `super().func(...)`
        if the `Class` being called is one of the bases."""
        if self.is_call_to_parent_class(updated_node):
            full_parent_class_name = get_full_attribute_name(updated_node.func.value)
            # Also treat `nn.Module` as a parent when `GradientCheckpointingLayer` is a base, and bare
            # `PreTrainedModel` as a parent of any `...PreTrainedModel` base.
            if (
                full_parent_class_name in self.new_bases
                or (full_parent_class_name == "nn.Module" and "GradientCheckpointingLayer" in self.new_bases)
                or (
                    full_parent_class_name == "PreTrainedModel"
                    and any("PreTrainedModel" in base for base in self.new_bases)
                )
            ):
                # Swap the `module.Class` prefix for a `super()` call node.
                attribute_node = updated_node.func.with_changes(value=cst.Call(func=cst.Name("super")))
                # Drop an explicit leading `self` argument, as `super().func(...)` binds it implicitly.
                new_args = (
                    updated_node.args[1:]
                    if len(updated_node.args) > 0 and m.matches(updated_node.args[0].value, m.Name("self"))
                    else updated_node.args
                )
                return updated_node.with_changes(func=attribute_node, args=new_args)
        return updated_node
| |
|
| |
|
class ReplaceSuperCallTransformer(cst.CSTTransformer):
    """
    This Transformer is used to unravel all calls to `super().func(...)` in class methods by the explicit parent's
    code. It will also in turn replace all calls of the form `module.Class.func(...)` by a call of the form
    `super().func(...)`. Those calls are used to explicitly skip the unravelling of code, but we should still follow
    python's standards and use `super().func(...)` instead of `Parent.func(self, ...)`.
    """

    def __init__(
        self,
        python_module: cst.Module,
        original_modeling_methods: dict[str, cst.FunctionDef],
        modular_methods: dict[str, cst.FunctionDef],
        new_bases: list[cst.Arg],
    ):
        # Module used only to render nodes back to source text (for text-based deduplication).
        self.python_module = python_module
        # Methods as defined in the original modeling file (the parent code to unravel).
        self.original_modeling_methods = original_modeling_methods
        # Methods redefined in the modular file (only these are candidates for unravelling).
        self.modular_methods = modular_methods
        self.all_assign_target = {}
        self.deleted_targets = {}
        new_bases = [get_full_attribute_name(base.value) for base in new_bases]
        self.parent_class_call_transformer = ReplaceParentClassCallTransformer(new_bases)

    def update_body(self, existing_body, new_statements):
        """
        Helper method to update the body by removing duplicates before adding new statements.
        `existing_body` is the body of the original method, the parent class
        `new_statements` are the additional statements
        """
        deduplicated_new_body = []
        existing_nodes = set()
        # First pass on the new statements: record assignment targets and `del` targets, so that
        # redefined/deleted variables override or remove the parent's version below.
        for node in new_statements:
            if m.matches(node, m.SimpleStatementLine(body=[m.Assign()])):
                target = self.python_module.code_for_node(node.body[0].targets[0].target)
                self.all_assign_target[target] = node
            if m.matches(node, m.SimpleStatementLine(body=[m.Del()])):
                target = self.python_module.code_for_node(node.body[0].target)
                self.deleted_targets[target] = node

        # Walk the parent's body: drop deleted targets, substitute re-assigned ones, skip the docstring.
        for stmt in existing_body:
            if m.matches(stmt, m.SimpleStatementLine(body=[m.Assign()])):
                target = self.python_module.code_for_node(stmt.body[0].targets[0].target)
                if target in self.deleted_targets:
                    continue
                if target in self.all_assign_target:
                    stmt = self.all_assign_target[target]
            # The parent's docstring is skipped (the child's one, if any, takes precedence).
            elif m.matches(stmt, DOCSTRING_NODE):
                continue
            # Record the comment-stripped source of each kept statement for deduplication.
            comment_less_code = re.sub(r"#.*", "", self.python_module.code_for_node(stmt)).strip()
            comment_less_code = re.sub(r"\ *\n", "\n", comment_less_code).strip()
            deduplicated_new_body.append(stmt)
            existing_nodes.add(comment_less_code)

        # Append the new statements that are not already present (comparing comment-stripped text).
        for node in new_statements:
            code = self.python_module.code_for_node(node)
            comment_less_code = re.sub(r"#.*", "", code).strip()
            comment_less_code = re.sub(r"\ *\n", "\n", comment_less_code).strip()
            if node not in deduplicated_new_body and comment_less_code not in existing_nodes:
                # `del` statements were only markers for removal; they are not kept in the final body.
                if not m.matches(node, m.SimpleStatementLine(body=[m.Del()])):
                    deduplicated_new_body.append(node)
                existing_nodes.add(comment_less_code)

        deduplicated_new_body = self._fix_post_init_location(deduplicated_new_body)

        return deduplicated_new_body

    def _fix_post_init_location(self, new_body: list[cst.CSTNode]):
        """Fix the location of the `post_init()` in the new body, if we added statements after the call to
        `super()` (it needs to be the very last statement called)"""
        # Scan for a `self.post_init(` call that is not already the last statement.
        for i, node in enumerate(new_body):
            code = self.python_module.code_for_node(node)
            comment_less_code = re.sub(r"#.*", "", code).strip()
            comment_less_code = re.sub(r"\ *\n", "\n", comment_less_code).strip()
            if "self.post_init(" in comment_less_code and i < len(new_body) - 1:
                # Move the statement to the very end of the body.
                new_body.pop(i)
                new_body.append(node)
                break
        return new_body

    def _fix_init_location(self, new_body, original_body):
        """
        Fix the location of the `super().__init__()` in the new body, if we had new statements before it.
        If the original class' `super().__init__()` is not in the beginning, do not fix it and leave where it is.
        In some cases we do not want to call super() at the very beginning.
        """
        # If the parent itself did not call `super().__init__` first (past the docstring), keep the order as-is.
        start_index = 0
        for i, node in enumerate(original_body):
            if m.matches(node, DOCSTRING_NODE) and i == start_index:
                start_index += 1
                continue
            code = self.python_module.code_for_node(node)
            comment_less_code = re.sub(r"#.*", "", code).strip()
            comment_less_code = re.sub(r"\ *\n", "\n", comment_less_code).strip()
            if "super().__init__" in comment_less_code and i > start_index:
                return new_body

        # Otherwise, hoist the `super().__init__` call right after the (optional) docstring.
        start_index = 0
        for i, node in enumerate(new_body):
            if m.matches(node, DOCSTRING_NODE) and i == start_index:
                start_index += 1
                continue
            code = self.python_module.code_for_node(node)
            comment_less_code = re.sub(r"#.*", "", code).strip()
            comment_less_code = re.sub(r"\ *\n", "\n", comment_less_code).strip()
            if "super().__init__" in comment_less_code and i > start_index:
                # Move the call to the beginning of the body (after the docstring).
                node = new_body.pop(i)
                new_body = new_body[:start_index] + [node] + new_body[start_index:]
                break
        return new_body

    def is_call_to_super(self, node: cst.BaseStatement, func_name: str):
        """Check whether `node` corresponds to a call to `super().func_name(...)`"""
        super_call_node = m.Call(func=m.Attribute(value=m.Call(func=m.Name("super")), attr=m.Name(func_name)))
        return m.matches(node, m.SimpleStatementLine(body=[m.Return(super_call_node) | m.Expr(super_call_node)]))

    def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
        func_name = updated_node.name.value
        # NOTE(review): `should_check_statements` does not appear to be read anywhere in this class — confirm
        # it is still needed.
        self.should_check_statements = False
        if func_name in self.modular_methods:
            actual_body = updated_node.body.body
            new_body = []
            for i, base_statement_node in enumerate(actual_body):
                if self.is_call_to_super(base_statement_node, func_name):
                    original_modeling_method_body = self.original_modeling_methods[func_name].body.body
                    new_body.extend(self.update_body(original_modeling_method_body, actual_body[i + 1 :]))
                    new_body = self._fix_init_location(new_body, original_modeling_method_body)
                    # Everything after the `super()` call was already consumed by `update_body` above.
                    break
                # For statements before the `super()` call (or when there is none), only rewrite
                # explicit parent-class calls (`Parent.func(self, ...)`) into `super().func(...)`.
                new_body.append(base_statement_node.visit(self.parent_class_call_transformer))
            return updated_node.with_changes(body=updated_node.body.with_changes(body=new_body))
        return updated_node
| |
|
| |
|
| | def find_all_dependencies( |
| | dependency_mapping: dict[str, set], |
| | start_entity: str | None = None, |
| | initial_dependencies: set | None = None, |
| | initial_checked_dependencies: set | None = None, |
| | return_parent: bool = False, |
| | ) -> list | set: |
| | """Return all the dependencies of the given `start_entity` or `initial_dependencies`. This is basically some kind of |
| | BFS traversal algorithm. It can either start from `start_entity`, or `initial_dependencies`. |
| | |
| | Args: |
| | dependency_mapping (`Dict[str, set]`): |
| | A mapping from entities (usually function/assignment names), to immediate dependencies. That is, for function names, |
| | a mapping {"foo": {"bar", "test"}} would indicate that functions `bar` and `test` are immediately called |
| | in `foo`'s definition. |
| | start_entity (str | None, *optional*): |
| | A key of `dependency_mapping`, indicating from which entity to start the search. |
| | initial_dependencies (set | None, *optional*): |
| | If `start_entity` is not provided, this can be used as an alternative. In this case, the search will continue |
| | from all the entities in `initial_dependencies`, if they are in `dependency_mapping`. |
| | initial_checked_dependencies (set | None, *optional*): |
| | If provided, entities already present in `initial_checked_dependencies` will not be part of the returned dependencies. |
| | return_parent (bool, *optional*): |
| | If `True`, will return a list consisting of tuples (dependency, parent) instead of a simple set of dependencies. Note |
| | that the order of the items in the list reflects the traversal order. Thus, no parent can ever appear before children. |
| | Returns: |
| | A set of all the dependencies, or a list of tuples `(dependency, parent)` if `return_parent=True`. |
| | |
| | Example: |
| | Given the following structure in the `modular_xxx.py` file: |
| | ``` |
| | def foo1(): |
| | pass |
| | |
| | def foo2(): |
| | pass |
| | |
| | def bar(): |
| | foo1() |
| | |
| | def foobar(): |
| | bar() |
| | foo2() |
| | |
| | class MyLayer(SomeOtherModelLayer): |
| | def forward(...): |
| | foobar() |
| | ``` |
| | and the `dependency_mapping` created when visiting the `modular_xxx.py` file, we get: |
| | ``` |
| | dependency_mapping = {'bar': {'foo1'}, 'foobar': {'bar', 'foo2'}} |
| | find_all_dependencies(dependency_mapping, start_entity='foobar', return_parent=True) |
| | >>> [('bar', 'foobar'), ('foo2', 'foobar'), ('foo1', 'bar')] |
| | ``` |
| | That is, all the functions needed (and potentially their immediate parent) so that the function to be added |
| | in MyLayer (`foobar`) can work correctly. |
| | """ |
| | if initial_dependencies is None and start_entity is not None: |
| | initial_dependencies = dependency_mapping[start_entity] |
| | if initial_checked_dependencies is None: |
| | initial_checked_dependencies = set() |
| |
|
| | dependency_queue = deque(initial_dependencies) |
| | all_dependencies = set() |
| | all_dependencies_with_parent = [] |
| | checked_dependencies = set(initial_checked_dependencies) |
| | parents = dict.fromkeys(initial_dependencies, start_entity) |
| | while len(dependency_queue) > 0: |
| | |
| | current = dependency_queue.popleft() |
| | if current not in checked_dependencies: |
| | |
| | all_dependencies.add(current) |
| | all_dependencies_with_parent += [(current, parents[current])] |
| | if current in dependency_mapping: |
| | |
| | dependency_queue.extend(dependency_mapping[current]) |
| | parents.update(dict.fromkeys(dependency_mapping[current], current)) |
| | |
| | checked_dependencies.add(current) |
| |
|
| | if not return_parent: |
| | return all_dependencies |
| | |
| | return all_dependencies_with_parent |
| |
|
| |
|
| | |
# Top-level assignments whose name matches one of these patterns are always kept in the generated file.
ASSIGNMENTS_REGEX_TO_KEEP = [r"_CHECKPOINT", r"_EXPECTED", r"_FOR_DOC", r"_HIDDEN_STATES_START_POSITION"]

# Top-level assignments matching these patterns are kept only when their value is not None.
ASSIGNMENTS_REGEX_TO_KEEP_IF_NOT_NONE = [r"_DOCSTRING"]
| |
|
| |
|
class ClassDependencyMapper(CSTVisitor):
    """A visitor which is designed to analyze a single class node to get all its dependencies that are shared with the set of
    `global_names`.
    """

    def __init__(
        self, class_name: str, global_names: set[str], objects_imported_from_modeling: set[str] | None = None
    ):
        super().__init__()
        self.class_name = class_name
        self.dependencies = set()
        self.global_names = global_names
        if objects_imported_from_modeling is None:
            objects_imported_from_modeling = set()
        self.objects_imported_from_modeling = objects_imported_from_modeling

    def visit_Name(self, node):
        # A Name is a dependency when it refers to a known global that is not the class itself
        # and was not already imported from a neighboring modeling file.
        name = node.value
        if name == self.class_name or name not in self.global_names:
            return
        if name not in self.objects_imported_from_modeling:
            self.dependencies.add(name)
| |
|
| |
|
def dependencies_for_class_node(node: cst.ClassDef, global_names: set[str]) -> set:
    """Create immediate dependencies for a class node based on the `global_names`."""
    # Wrap the class in a throwaway module so the visitor can traverse it.
    mapper = ClassDependencyMapper(node.name.value, global_names)
    cst.Module(body=[node]).visit(mapper)
    return mapper.dependencies
| |
|
| |
|
def augmented_dependencies_for_class_node(
    node: cst.ClassDef, mapper: "ModuleMapper", objects_imported_from_modeling: set[str] | None = None
) -> set:
    """Create augmented dependencies for a class node based on a `mapper`.
    Augmented dependencies means immediate dependencies + recursive function and assignments dependencies.
    """
    # Collect the immediate dependencies, then let the mapper expand them recursively.
    visitor = ClassDependencyMapper(node.name.value, set(mapper.global_nodes.keys()), objects_imported_from_modeling)
    cst.Module(body=[node]).visit(visitor)
    return mapper.augment_dependencies(visitor.dependencies)
| |
|
| |
|
| | |
# All the file type prefixes a model package may contain; used both to match neighbor-file imports
# and to rewrite per-file-type module names.
ALL_FILE_TYPES = (
    "modeling",
    "configuration",
    "tokenization",
    "processing",
    "image_processing.*_fast",  # "image_processing_*_fast" must come before plain "image_processing"
    "image_processing",
    "video_processing",
    "feature_extraction",
)
| |
|
| |
|
class ModuleMapper(CSTVisitor, ABC):
    """An abstract visitor class which analyses a module, creating a mapping of dependencies for classes, functions and assignments.
    Class dependencies are computed with `compute_class_dependencies()`, while function and assignment dependencies are stored in
    `self.object_recursive_dependency_mapping` (can be computed by `_compute_recursive_object_dependencies()`).
    It defines common visiting patterns (i.e. common visit_xxx/leave_xxx functions) between the modular file and the
    modeling files that will be visited.
    """

    METADATA_DEPENDENCIES = (ParentNodeProvider, PositionProvider)

    def __init__(self, python_module: cst.Module):
        # The parsed module being visited; used to render nodes back to source text.
        self.python_module: cst.Module = python_module
        self.classes: dict[str, cst.ClassDef] = {}  # mapping from class names to class nodes
        self.imports = []  # all import statement nodes (including guarded `if` blocks, see `visit_If`)
        self.functions: dict[str, cst.FunctionDef] = {}  # mapping from module-scope function names to nodes
        self.object_dependency_mapping = defaultdict(set)  # immediate name dependencies of functions/assignments
        self.assignments: dict[str, cst.SimpleStatementLine] = {}  # mapping from global assignment names to nodes
        self.current_function = None  # name of the module-scope function currently being visited, if any
        self.current_class = None  # name of the module-scope class currently being visited, if any
        self.current_assignment = None  # name of the module-scope assignment currently being visited, if any
        # Objects imported from neighbor modeling-type files; they are not dependencies to re-emit.
        self.objects_imported_from_modeling = set()
        # Regex alternation over all file-type prefixes, used to detect neighbor-file imports.
        self.match_patterns = "|".join(ALL_FILE_TYPES)

    def visit_ImportFrom(self, node):
        """This keeps track of objects imported from neighbor modeling files (e.g. in `modeling_xxx.py, we have
        `from .configuration_xxx import Config`, then `Config` should be recorded as it is not a dependency that needs
        to be added (because it will be part of the imports)"""
        # Rebuild the textual import path, prefixing one dot per relative level.
        import_module = self.python_module.code_for_node(node.module) if node.module is not None else ""
        import_statement = "." * len(node.relative) + import_module
        if re.search(rf"^\.({self.match_patterns}).*", import_statement):
            for imported_object in node.names:
                # If an alias is present, record the alias (it is the name used in this file).
                if imported_object.evaluated_alias is not None:
                    self.objects_imported_from_modeling.add(imported_object.evaluated_alias)
                else:
                    self.objects_imported_from_modeling.add(imported_object.evaluated_name)

    def visit_SimpleStatementLine(self, node):
        """
        Global Assigns like `GEMMA_INPUT_DOCSTRING = 'THIS IS THE INPUT'` and all import statements
        are extracted and saved in their corresponding dict. They are then used when updating dependency mappings.
        """
        parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node)
        # Plain `NAME = ...` at top level.
        simple_top_level_assign_structure = m.SimpleStatementLine(
            body=[m.Assign(targets=[m.AssignTarget(target=m.Name())])]
        )
        # `NAME[...] = ...` or `NAME.attr = ...` at top level.
        simple_top_level_variable_indexing = m.SimpleStatementLine(
            body=[m.Assign(targets=[m.AssignTarget(target=m.Subscript(value=m.Name()) | m.Attribute(value=m.Name()))])]
        )
        # Only module-scope statements are recorded.
        if m.matches(parent_node, m.Module()):
            if m.matches(node, simple_top_level_assign_structure):
                left_hand_side = node.body[0].targets[0].target.value
                self.current_assignment = left_hand_side
                self.assignments[left_hand_side] = node
            # An indexed/attribute assignment both defines a new node and depends on the indexed variable.
            elif m.matches(node, simple_top_level_variable_indexing):
                indexed_variable = node.body[0].targets[0].target.value.value
                # Names referenced inside this statement become dependencies of the indexed variable.
                self.current_assignment = indexed_variable
                # Use the full statement source as the key, since there is no single name for it.
                node_name = self.python_module.code_for_node(node)
                self.assignments[node_name] = node
                self.object_dependency_mapping[indexed_variable].add(node_name)
            elif m.matches(node, m.SimpleStatementLine(body=[m.Import() | m.ImportFrom()])):
                self.imports.append(node)

    def leave_SimpleStatementLine(self, node):
        # Reset the tracker set in `visit_SimpleStatementLine`; harmless when the statement
        # was not an assignment (it is already None).
        self.current_assignment = None

    def visit_FunctionDef(self, node):
        # Only record module-scope functions (methods are handled through their class).
        parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node)
        if m.matches(parent_node, m.Module()):
            self.current_function = node.name.value
            self.functions[node.name.value] = node

    def leave_FunctionDef(self, node):
        parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node)
        if m.matches(parent_node, m.Module()):
            self.current_function = None

    def visit_If(self, node):
        # A module-scope `if` containing imports (an availability guard) is recorded wholesale
        # with the imports.
        if self.current_function is None and self.current_class is None:
            for stmt in node.body.body:
                if m.matches(stmt, m.SimpleStatementLine(body=[m.ImportFrom() | m.Import()])):
                    self.imports.append(node)

    def visit_ClassDef(self, node: ClassDef) -> None:
        """Record class nodes to create their dependencies at the end."""
        self.classes[node.name.value] = node
        self.current_class = node.name.value

    def leave_ClassDef(self, node):
        self.current_class = None

    def visit_Name(self, node: cst.Name):
        """This is used to create a mapping from module-scope functions and assignments to objects used inside them."""
        # Every Name is recorded here; non-global names are filtered out later by
        # `_restrict_dependencies_to_known_entities`.
        if self.current_function is not None:
            self.object_dependency_mapping[self.current_function].add(node.value)
        if self.current_assignment is not None:
            self.object_dependency_mapping[self.current_assignment].add(node.value)

    def leave_Module(self, node):
        """When leaving the module, we store the position of each global scoped node to allow sorting the dependencies
        based on their position in the code later. We use the PositionProvider metadata wrapper for this.
        We also make sure to update `self.object_dependency_mapping` so that it contains only names recorded in
        `self.global_nodes`.
        """
        # Mapping from name to node for every global-scope object (assignments, classes, functions).
        self.global_nodes = {**self.assignments, **self.classes, **self.functions}
        # Record the source line each global node starts at. (NOTE: `id` shadows the builtin here.)
        self.start_lines = {}
        for id, node in self.global_nodes.items():
            self.start_lines[id] = self.get_metadata(cst.metadata.PositionProvider, node).start.line

    def _restrict_dependencies_to_known_entities(self):
        """Since we added every Name as part of `self.object_dependency_mapping`, we need to remove those that
        are not part of the recorded objects in `self.global_nodes` (i.e. built-in variables, imports, etc).
        This should be called only after all merging operations have been finalized!!"""
        global_objects = set(self.global_nodes.keys())
        for object_name, dependencies in self.object_dependency_mapping.items():
            self.object_dependency_mapping[object_name] = {dep for dep in dependencies if dep in global_objects}

    def _compute_recursive_object_dependencies(self) -> dict[str, set]:
        """Based on immediate dependency mapping, create the recursive dependency mapping. For example, given the
        following file:
        ```
        def foo():
            pass

        def bar():
            foo()

        def test():
            bar()
        ```
        this visitor can only record immediate dependencies, i.e. it will record the following
        `self.object_dependency_mapping = {"test": {"bar"}, "bar": {"foo}}`. This function is used to create
        the recursive mapping, i.e. `recursive_dependencies = {"test": {"bar", "foo"}, "bar": {"foo}}`.
        """
        recursive_dependencies = {}
        for object_name in self.object_dependency_mapping:
            all_dependencies = find_all_dependencies(self.object_dependency_mapping, start_entity=object_name)
            recursive_dependencies[object_name] = all_dependencies
        return recursive_dependencies

    def augment_dependencies(self, dependencies: set[str]) -> set[str]:
        """For a set of `dependencies`, augment them by adding all potential dependencies of the **functions** and
        **assignments** present in the `dependencies`.
        """
        new_dependencies = dependencies.copy()
        # Iterate over a snapshot so the set can be updated while looping.
        for dep in tuple(dependencies):
            if dep in self.object_recursive_dependency_mapping:
                new_dependencies.update(self.object_recursive_dependency_mapping[dep])
        return new_dependencies

    def compute_class_dependencies(self):
        """For each visited class, find its dependencies based on visiting the current file + potential merged dependencies."""
        self.class_dependency_mapping = {}
        for class_name, class_node in self.classes.items():
            dependencies = dependencies_for_class_node(class_node, set(self.global_nodes.keys()))
            # Correctly augment class dependencies with all needed objects.
            self.class_dependency_mapping[class_name] = self.augment_dependencies(dependencies)

    @abstractmethod
    def compute_relative_order(self, missing_dependencies: set) -> dict[str, int]:
        raise NotImplementedError
| |
|
| |
|
| | class ModelFileMapper(ModuleMapper): |
| | """A mapper designed to parse modeling files (like `modeling_llama.py`). When encountering such a file |
| | in the `modular_xxx.py` file, we need to correctly visit it and merge the dependencies of the modular and current file. |
| | For this reason, this class should only be instantiated from the class method `visit_and_merge_dependencies`, which takes |
| | care of correctly merging dependencies, then finalizes all dependency graph computations. |
| | Note that we only merge functions and assignments here, as classes will be treated later on as they may be modified. |
| | For example, if you redefine `apply_rotary_pos_emb()` in the modular, the new node should be used in the dependencies |
| | of the modeling files as well. |
| | """ |
| |
|
    def __init__(self, python_module: cst.Module):
        """Initialize the mapper on the parsed modeling file `python_module`."""
        super().__init__(python_module)
| |
|
    def compute_relative_order(self, missing_dependencies: set[str]) -> dict[str, int]:
        """Compute in which relative order the `missing_dependencies` should appear when the nodes are added to the final file that
        will be created based on the modular.
        """
        relative_order = {}
        idx = 0
        # Classes present in this file, in source order.
        classes = sorted(
            [dep for dep in tuple(missing_dependencies) if dep in self.classes], key=lambda x: self.start_lines[x]
        )
        # For merged dependencies we only have a relative order inside the other visited file, so we need the
        # class dependency mapping to anchor them relative to a given class.
        if len(classes) > 0 and not hasattr(self, "class_dependency_mapping"):
            raise ValueError("Cannot correctly find the relative order of the dependencies.")

        remaining_dependencies = missing_dependencies.copy()

        # First, order dependencies class by class.
        for class_name in classes:
            class_dependencies = tuple(self.class_dependency_mapping[class_name] & remaining_dependencies)
            original_dependencies = []
            merged_dependencies = []
            # Distinguish nodes already present in this file (global order known via `start_lines`) from
            # nodes merged in from the modular file (order known only via `modular_file_start_lines`).
            for class_dep in class_dependencies:
                if class_dep in self.start_lines:
                    original_dependencies.append(class_dep)
                else:
                    merged_dependencies.append(class_dep)
            # Sort deterministically (alphabetically, reversed) first so ties resolve the same way every run,
            # then stable-sort by position in the respective source file.
            original_dependencies = sorted(original_dependencies, reverse=True)
            original_dependencies = sorted(original_dependencies, key=lambda x: self.start_lines.get(x, 1e10))
            merged_dependencies = sorted(merged_dependencies, key=lambda x: self.modular_file_start_lines[x])

            # Emit all original nodes first, then the merged ones.
            for dep in original_dependencies + merged_dependencies:
                remaining_dependencies.remove(dep)
                relative_order[dep] = idx
                idx += 1
            # Emit the class itself (it may already have been emitted as a dependency of an earlier class).
            if class_name in remaining_dependencies:
                remaining_dependencies.remove(class_name)
                relative_order[class_name] = idx
                idx += 1

        # Then order whatever non-class dependencies remain, using the same two-pass sort.
        remaining_dependencies = tuple(remaining_dependencies)
        original_dependencies = []
        merged_dependencies = []
        for dep in remaining_dependencies:
            if dep in self.modular_file_start_lines:
                merged_dependencies.append(dep)
            else:
                original_dependencies.append(dep)
        # Deterministic pre-sort, then stable sort by source position (unknown positions go last via 1e10).
        original_dependencies = sorted(original_dependencies, reverse=True)
        original_dependencies = sorted(original_dependencies, key=lambda x: self.start_lines.get(x, 1e10))
        merged_dependencies = sorted(merged_dependencies, key=lambda x: self.modular_file_start_lines[x])

        # Emit all original nodes first, then the merged ones.
        for dep in original_dependencies + merged_dependencies:
            relative_order[dep] = idx
            idx += 1

        return relative_order
| |
|
| | def _merge_functions(self, functions: dict[str, cst.CSTNode], object_mapping: dict[str, set]): |
| | """Update the global nodes and function dependency mapping with those from the modular file. |
| | |
| | Merging rule: if any function with the same name was redefined in the modular, use it and its dependencies |
| | instead of the original ones (this may mean to add new functions as well, if any redefined function uses a new one). |
| | """ |
| | |
| | self.functions.update(functions) |
| | self.object_dependency_mapping.update({obj: dep for obj, dep in object_mapping.items() if obj in functions}) |
| | |
| | self.global_nodes.update(self.functions) |
| |
|
| | def _merge_assignments(self, assignments: dict[str, cst.CSTNode], object_mapping: dict[str, set]): |
| | """Update the global nodes with the assignment from the modular file. |
| | |
| | Merging rule: if any assignment with the same name was redefined in the modular, we use it and its dependencies ONLY if it matches |
| | a pattern in `ASSIGNMENTS_REGEX_TO_KEEP_IF_NOT_NONE` and its value is not None, or if it matches a pattern in `ASSIGNMENTS_REGEX_TO_KEEP. |
| | Otherwise, we use the original value and dependencies. This rule was chosen to avoid having to rewrite the big docstrings. |
| | """ |
| | for assignment, node in assignments.items(): |
| | should_keep = any(re.search(pattern, assignment) for pattern in ASSIGNMENTS_REGEX_TO_KEEP) |
| |
|
| | should_keep_if_not_none = any( |
| | re.search(pattern, assignment) for pattern in ASSIGNMENTS_REGEX_TO_KEEP_IF_NOT_NONE |
| | ) and not (hasattr(node.body[0].value, "value") and node.body[0].value.value == "None") |
| |
|
| | if should_keep or should_keep_if_not_none or assignment not in self.assignments: |
| | self.assignments[assignment] = node |
| | if assignment in object_mapping: |
| | self.object_dependency_mapping[assignment] = object_mapping[assignment] |
| | |
| | self.global_nodes.update(self.assignments) |
| |
|
| | def _merge_classes(self, classes: dict[str, cst.CSTNode]): |
| | """Update the global nodes with the new classes from the modular (i.e. classes which do not exist in current file, and |
| | are not imported). We do NOT update any dependency mapping here. This is because we only need the names of newly defined |
| | classes in the modular to be discoverable when computing dependencies for new nodes later on. For this reason, we |
| | do not add the new classes to `self.classes`, but only to `global_nodes`. |
| | """ |
| | |
| | self.global_nodes.update( |
| | { |
| | name: node |
| | for name, node in classes.items() |
| | if name not in self.classes and name not in self.objects_imported_from_modeling |
| | } |
| | ) |
| |
|
| | def merge_modular_dependencies(self, classes, functions, assignments, object_mapping, start_lines): |
| | """Merge classes, functions and assignments from the modular definitions into the current module file, |
| | then record the relative order of all nodes. |
| | Note: This function takes care of updating `global_nodes` and `object_recursive_dependency_mapping` as well after the |
| | merge with other files dependencies. |
| | """ |
| | self._merge_functions(functions, object_mapping) |
| | self._merge_assignments(assignments, object_mapping) |
| | self._merge_classes(classes) |
| | self.modular_file_start_lines = start_lines |
| |
|
| | |
| | self._restrict_dependencies_to_known_entities() |
| | |
| | self.object_recursive_dependency_mapping = self._compute_recursive_object_dependencies() |
| |
|
| | @classmethod |
| | def visit_and_merge_dependencies( |
| | cls, module: cst.Module, classes, functions, assignments, object_mapping, start_lines |
| | ) -> "ModelFileMapper": |
| | wrapper = MetadataWrapper(module) |
| | mapper = cls(module) |
| | wrapper.visit(mapper) |
| | |
| | mapper.merge_modular_dependencies(classes, functions, assignments, object_mapping, start_lines) |
| | |
| | mapper.compute_class_dependencies() |
| | return mapper |
| |
|
| |
|
def common_partial_suffix(str1: str, str2: str) -> str:
    """Return the biggest common suffix between 2 strings. If one string is a full suffix of the other string,
    we do not consider it a common suffix and return `""`"""
    limit = min(len(str1), len(str2))
    matched = 0
    # Walk backwards from the end of both strings while the characters agree.
    while matched < limit and str1[-(matched + 1)] == str2[-(matched + 1)]:
        matched += 1
    suffix = str1[len(str1) - matched :] if matched > 0 else ""
    # A suffix covering one of the strings entirely is not considered "common".
    return "" if suffix in (str1, str2) else suffix
| |
|
| |
|
def replace_class_node(
    mapper: ModelFileMapper, modular_class_node: cst.ClassDef, renamed_super_class: str, original_super_class: str
) -> cst.ClassDef:
    """
    Replace a class node which inherits from another modeling class. This function works in the following way:
    - start from the methods and class attributes of the original modeling code node, and replace their definition
    if overridden in the modular
    - append all new methods and class attributes defined in the child class
    - all potential method/class docstrings and decorators use the ones found in modular if any, else in original modeling
    - replace all calls to super() with the unravelled code

    Args:
        mapper (`ModelFileMapper`):
            The mapper corresponding to the visited file from which the modular class node inherits.
        modular_class_node (`cst.ClassDef`):
            The class node as found in the modular file.
        renamed_super_class (`str`):
            The name of the class from which `modular_class_node` inherits after automatic renaming.
        original_super_class (`str`):
            The name of the class from which `modular_class_node` inherits before automatic renaming.

    Returns:
        A new class node corresponding to the modular definition.
    """
    all_bases = [get_full_attribute_name(k.value) for k in modular_class_node.bases]
    if any(base is None for base in all_bases):
        raise ValueError(f"Could not parse the name of the bases for {modular_class_node.name.value}")

    original_modeling_node = mapper.classes[renamed_super_class]
    new_class_name = modular_class_node.name

    # If the new class has a different name than its super class, propagate the renaming inside the
    # docstrings of the original node (only the common camel-case suffix is kept, the prefixes differ).
    if new_class_name.value != renamed_super_class:
        common_suffix = common_partial_suffix(new_class_name.value, renamed_super_class)
        old, new = renamed_super_class.replace(common_suffix, ""), new_class_name.value.replace(common_suffix, "")
        temp_module = cst.Module(body=[original_modeling_node])
        original_modeling_node = temp_module.visit(
            ReplaceNameTransformer(get_lowercase_name(old), get_lowercase_name(new), only_doc=True)
        ).body[0]

    # Swap original bases for extra bases declared in the modular when they share an upper-cased suffix
    # (e.g. replacing a mixin by a model-specific variant of it).
    additional_bases = [base for base in all_bases if base != original_super_class]
    new_class_bases = []
    for original_base in original_modeling_node.bases:
        new_base = original_base
        # NOTE(review): only simple `Name` bases are candidates for replacement; `Attribute` bases pass through.
        if m.matches(original_base.value, m.Name()):
            original_base_name = original_base.value.value
            for additional_base_name in additional_bases:
                suffix = common_partial_suffix(original_base_name, additional_base_name)
                if len(suffix) > 0 and suffix[0].isupper():
                    new_name_node = original_base.value.with_changes(value=additional_base_name)
                    new_base = original_base.with_changes(value=new_name_node)
                    break
        new_class_bases.append(new_base)

    # Decorators: modular wins if it declares any, else keep the modeling ones.
    new_class_decorators = (
        modular_class_node.decorators if len(modular_class_node.decorators) > 0 else original_modeling_node.decorators
    )

    # Class docstring: same rule -- modular wins if present.
    original_modeling_docstring = [
        node for node in original_modeling_node.body.body if m.matches(node, DOCSTRING_NODE)
    ]
    modular_docstring = [node for node in modular_class_node.body.body if m.matches(node, DOCSTRING_NODE)]
    new_class_docstring = modular_docstring if len(modular_docstring) > 0 else original_modeling_docstring

    # Collect class attributes (plain and annotated assignments) from both sides.
    original_modeling_class_attributes = {}
    for node in original_modeling_node.body.body:
        if m.matches(node, m.SimpleStatementLine(body=[m.Assign()])):
            original_modeling_class_attributes[node.body[0].targets[0].target.value] = node
        elif m.matches(node, m.SimpleStatementLine(body=[m.AnnAssign()])):
            original_modeling_class_attributes[node.body[0].target.value] = node

    modular_class_attributes = {}
    for node in modular_class_node.body.body:
        if m.matches(node, m.SimpleStatementLine(body=[m.Assign()])):
            modular_class_attributes[node.body[0].targets[0].target.value] = node
        elif m.matches(node, m.SimpleStatementLine(body=[m.AnnAssign()])):
            modular_class_attributes[node.body[0].target.value] = node

    # Dict-merge: modular attribute definitions override the modeling ones.
    new_class_attributes = list({**original_modeling_class_attributes, **modular_class_attributes}.values())

    # Collect methods from both sides. A duplicated method name is assumed to be a @property / setter
    # pair; the setter is stored under `<name>_setter` -- TODO confirm no other duplication source exists.
    original_modeling_methods = {}
    for node in original_modeling_node.body.body:
        if m.matches(node, m.FunctionDef()):
            if node.name.value in original_modeling_methods:
                # If the second occurrence is the @property, keep it under the plain name and move the
                # previously-stored node (the setter) under the `_setter` key.
                if node.decorators[0].decorator.value == "property":
                    original_modeling_methods[f"{node.name.value}_setter"] = original_modeling_methods[node.name.value]
                    original_modeling_methods[node.name.value] = node
                else:
                    original_modeling_methods[f"{node.name.value}_setter"] = node
            else:
                original_modeling_methods[node.name.value] = node
    modular_methods = {}
    for node in modular_class_node.body.body:
        if m.matches(node, m.FunctionDef()):
            # Same property/setter disambiguation as above.
            if node.name.value in modular_methods:
                if node.decorators[0].decorator.value == "property":
                    modular_methods[f"{node.name.value}_setter"] = modular_methods[node.name.value]
                    modular_methods[node.name.value] = node
                else:
                    modular_methods[f"{node.name.value}_setter"] = node
            else:
                modular_methods[node.name.value] = node

    new_class_methods = []
    # Start from the modeling methods, replacing pieces overridden in the modular.
    for name, node in original_modeling_methods.items():
        if name in modular_methods:
            modular_node = modular_methods[name]

            # A modular override whose whole body is `raise ...Error(...)` means "delete this method".
            if re.match(r"\ndef .*\(.*\):\n raise.*Error\(.*", mapper.python_module.code_for_node(modular_node)):
                continue

            # Body: take the modular body; if it lacks a docstring, prepend the modeling one.
            modeling_docstring = [node_ for node_ in node.body.body if m.matches(node_, DOCSTRING_NODE)]
            modular_docstring = [node_ for node_ in modular_node.body.body if m.matches(node_, DOCSTRING_NODE)]
            new_body = (
                modular_node.body.body
                if len(modular_docstring) > 0
                else modeling_docstring + list(modular_node.body.body)
            )
            new_body = modular_node.body.with_changes(body=new_body)

            # Parameters: the modular signature wins by default.
            new_params = modular_node.params

            # Special case: a `**super_kwargs` catch-all means "merge my explicit params into the
            # original signature" (original params first, modular overrides, keep original **kwargs).
            kwarg_name = getattr(modular_node.params, "star_kwarg", None)
            if kwarg_name and kwarg_name.name.value == "super_kwargs":
                original_modeling_params = {k.name.value: k for k in node.params.params}
                modular_params = {k.name.value: k for k in new_params.params[1:]}
                new_param_list = list({**original_modeling_params, **modular_params}.values())
                new_params = new_params.with_changes(params=new_param_list, star_kwarg=node.params.star_kwarg)

            # Decorators: modular wins if it declares any.
            new_decorators = modular_node.decorators if len(modular_node.decorators) > 0 else node.decorators

            # Return annotation: modular wins if present.
            new_return_annotation = modular_node.returns if modular_node.returns else node.returns

            # Rebuild the method on top of the ORIGINAL node (keeps its surrounding whitespace/comments).
            node = node.with_changes(
                body=new_body,
                params=new_params,
                decorators=new_decorators,
                returns=new_return_annotation,
            )

        new_class_methods.append(node)

    # Append methods that only exist in the modular.
    for name, node in modular_methods.items():
        if name not in original_modeling_methods:
            new_class_methods.append(node)

    # Assemble: docstring, then attributes, then methods.
    new_class_body = new_class_docstring + new_class_attributes + new_class_methods

    # Unravel `super()` calls in the assembled body (needs a temporary Module so the transformer can
    # render code for its nodes).
    result_node = original_modeling_node.with_changes(body=cst.IndentedBlock(body=new_class_body))
    temp_module = cst.Module(body=[result_node])
    new_replacement_class = temp_module.visit(
        ReplaceSuperCallTransformer(temp_module, original_modeling_methods, modular_methods, new_class_bases)
    )
    new_class_body = new_replacement_class.body[0].body

    return original_modeling_node.with_changes(
        body=new_class_body, decorators=new_class_decorators, bases=new_class_bases, name=new_class_name
    )
| |
|
| |
|
# Maps a class-name suffix (what remains after stripping the cased model name, e.g. `Config` from
# `LlamaConfig`) to the file-type stem of the generated file that should host it. Values may be regex
# fragments (e.g. `image_processing.*_fast`). Unmatched suffixes default to `modeling` in `find_file_type`.
TYPE_TO_FILE_TYPE = {
    "Config": "configuration",
    "Tokenizer": "tokenization",
    "Processor": "processing",
    "ImageProcessor": "image_processing",
    "ImageProcessorFast": "image_processing.*_fast",
    "VideoProcessor": "video_processing",
    "VideoProcessorInitKwargs": "video_processing",
    "FastImageProcessorKwargs": "image_processing.*_fast",
    "ImageProcessorKwargs": "image_processing",
    "FeatureExtractor": "feature_extraction",
    "ProcessorKwargs": "processing",
    "VideosKwargs": "processing",
    "ImagesKwargs": "processing",
    "TextKwargs": "processing",
}
| |
|
| |
|
def find_file_type(class_name: str, model_name: str) -> str:
    """Based on a class name, find the file type corresponding to the class.
    If the class name is `LlamaConfig` it will return `configuration`.
    The list of suffixes is in `TYPE_TO_FILE_TYPE`. If there are no match, we match by default to `modeling`
    """
    suffix_alternatives = "|".join(TYPE_TO_FILE_TYPE.keys())
    # Strip the cased model name so only the type suffix remains, then match that suffix at the end.
    stripped_name = class_name.replace(get_cased_name(model_name), "")
    suffix_match = re.search(rf"({suffix_alternatives})$", stripped_name)
    return TYPE_TO_FILE_TYPE[suffix_match.group(1)] if suffix_match else "modeling"
| |
|
| |
|
| | |
| | |
# Variable names whose assignments must always be placed at the very top of a generated file,
# right after the imports.
VARIABLES_AT_THE_BEGINNING = (
    "logger",
    "_CHECKPOINT_FOR_DOC",
    "_CONFIG_FOR_DOC",
)


# Import sources that are never recorded/followed when parsing a modular file.
IMPORTS_TO_SKIP_IN_MODULAR = ("auto.modeling_auto",)
| |
|
| |
|
def append_new_import_node(
    node: cst.CSTNode, unused_imports: set[str], added_names: set, imports_to_keep: list[cst.CSTNode]
):
    """Insert the new `node` to the list of `imports_to_keep` in-place, if it is not part of the `unused_imports` or `added_names`.
    Also modifies `added_names` in-place accordingly."""
    import_node = node.body[0]
    kept_aliases = []
    for alias in import_node.names:
        # The alias (`as` name) is what ends up in scope; fall back to the plain imported name.
        effective_name = alias.evaluated_alias or alias.evaluated_name
        # Skip names flagged as unused, and names already emitted by an earlier import statement.
        if effective_name in unused_imports or effective_name in added_names:
            continue
        kept_aliases.append(alias.with_changes(comma=cst.MaybeSentinel.DEFAULT))
        added_names.add(effective_name)
    if kept_aliases:
        imports_to_keep.append(node.with_changes(body=[import_node.with_changes(names=kept_aliases)]))
| |
|
| |
|
def get_needed_imports(body: dict[str, dict], all_imports: list[cst.CSTNode]) -> list[cst.CSTNode]:
    """Get all the imports needed in the `body`, from the list of `all_imports`.
    `body` is a dict with the following structure `{str: {"insert_idx": int, "node": cst.CSTNode}}`.
    Note: we need to use `isinstance` on scope assignments, m.matches apparently does not work here yet!
    """
    # Build a throwaway module (imports + body, body sorted by insertion index) to run scope analysis on.
    new_body = [k[1]["node"] for k in sorted(body.items(), key=lambda x: x[1]["insert_idx"])]
    wrapper = MetadataWrapper(cst.Module(body=all_imports + new_body))
    scopes = set(wrapper.resolve(ScopeProvider).values())
    unused_imports = set()
    import_ref_count = defaultdict(lambda: 0)
    # Count references to each imported name across all scopes (keep the max over scopes).
    for scope in scopes:
        for assignment in scope.assignments:
            node = assignment.node
            if isinstance(assignment, cst.metadata.Assignment) and isinstance(node, (cst.Import, cst.ImportFrom)):
                ref_count = len(assignment.references)
                name = assignment.name
                import_ref_count[name] = max(ref_count, import_ref_count[name])
    # An import is unused if never referenced, or if the body re-defines the same name itself.
    unused_imports = {name for name, count in import_ref_count.items() if count <= 0 or name in body}

    imports_to_keep = []
    # Deduplicate imported names and protected (`if ...:` guarded) import statements while rebuilding.
    added_names = set()
    existing_protected_statements = set()
    for node in all_imports:
        if m.matches(node, m.If()):
            # Protected import block: filter each inner statement, drop duplicates of already-kept ones.
            new_statements = []
            for stmt_node in node.body.body:
                append_new_import_node(stmt_node, unused_imports, added_names, new_statements)
            new_statements = [stmt for stmt in new_statements if str(stmt) not in existing_protected_statements]
            if len(new_statements) > 0:
                new_node = node.with_changes(body=node.body.with_changes(body=new_statements))
                imports_to_keep.append(new_node)
                existing_protected_statements.update({str(stmt) for stmt in new_statements})
        else:
            append_new_import_node(node, unused_imports, added_names, imports_to_keep)

    protected_import_nodes = [node for node in imports_to_keep if m.matches(node, m.If())]
    usual_import_nodes = [node for node in imports_to_keep if not m.matches(node, m.If())]

    # Plain imports first, conditional (protected) import blocks last.
    return usual_import_nodes + protected_import_nodes
| |
|
| |
|
def split_all_assignment(node: cst.CSTNode, model_name: str) -> dict[str, cst.CSTNode]:
    """Split the `__all__` assignment found in the modular between each corresponding files."""
    per_file_all = {}
    assign_node = node.body[0]
    if isinstance(assign_node.value, cst.List):
        # Dispatch every exported name to the file type its class-name suffix maps to.
        names_per_file = defaultdict(list)
        for element in assign_node.value.elements:
            if isinstance(element.value, cst.SimpleString):
                quoted_name = element.value.value
                target_file = find_file_type(element.value.evaluated_value, model_name)
                names_per_file[target_file].append(quoted_name)
        # Rebuild one `__all__` list node per target file, preserving the original statement wrapper.
        for target_file, exported_names in names_per_file.items():
            new_list = cst.List(elements=[cst.Element(value=cst.SimpleString(value=name)) for name in exported_names])
            per_file_all[target_file] = node.with_changes(body=[assign_node.with_changes(value=new_list)])
    return per_file_all
| |
|
| |
|
| | class ModularFileMapper(ModuleMapper): |
| | """This is a Mapper to visit a modular file (like `modular_llama.py`). It visits the whole file, recording dependency, |
| | then visits all imported modeling files (like `modeling_llama.py`), and manages their mutual dependencies. |
| | Calling the method `create_modules()` after visit will create all modules based on this modular file. |
| | """ |
| |
|
    def __init__(self, python_module, new_name, package_name):
        super().__init__(python_module)
        # Model name derived from the modular filename (e.g. "llama" for `modular_llama.py`).
        self.model_name = new_name

        # imported object name -> source module name, and source module name -> parsed cst.Module.
        self.model_specific_imported_objects: dict[str, str] = {}
        self.model_specific_modules: dict[str, cst.Module] = {}

        # Per-file `__all__` nodes, filled by `visit_SimpleStatementLine` via `split_all_assignment`.
        self.all_all_to_add = {}

        # External files to skip when resolving model-specific imports outside the `transformers` package.
        self.excluded_external_files = {} if package_name == "transformers" else EXCLUDED_EXTERNAL_FILES[package_name]
| |
|
    def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
        """When visiting imports from modeling files (i.e. `transformers.models.xxx`) we get the code, parse it,
        and save it in `self.model_specific_modules` to later visit. The imported objects are saved in `self.model_specific_imported_objects`.
        """
        # Rebuild the textual import path, including leading dots for relative imports.
        import_module = self.python_module.code_for_node(node.module) if node.module is not None else ""
        import_statement = "." * len(node.relative) + import_module
        if any(import_to_skip in import_statement for import_to_skip in IMPORTS_TO_SKIP_IN_MODULAR):
            return
        # Dotted module (e.g. `..llama.modeling_llama`): candidate model-specific import.
        if m.matches(node.module, m.Attribute()):
            for imported_ in node.names:
                # Explicitly excluded external files are ignored entirely.
                if any(external_file["name"] in import_statement for external_file in self.excluded_external_files):
                    continue
                _import = re.search(
                    rf"(?:transformers\.models\.)|(?:\.\.\.models\.)|(?:\.\.)\w+\.({self.match_patterns}).*",
                    import_statement,
                )
                if _import:
                    source = _import.group(1)
                    # Configs must be imported from the configuration file, never from the modeling file.
                    if source == "modeling" and "Config" in self.python_module.code_for_node(imported_):
                        raise ValueError(
                            f"You are importing {self.python_module.code_for_node(imported_)} from the modeling file. Import from the `configuration_xxxx.py` file instead"
                        )
                    if import_module not in self.model_specific_modules:
                        # Normalize to an absolute `transformers.models.xxx` module path before loading.
                        if "models" not in import_module:
                            import_module = "models." + import_module
                        if not import_module.startswith("transformers"):
                            import_module = "transformers." + import_module
                        try:
                            source_code = get_module_source_from_name(import_module)
                        except ModuleNotFoundError as e:
                            raise ModuleNotFoundError(
                                f"Failed to visit import from for: {self.python_module.code_for_node(node)}. Tried to import {import_module} but failed."
                            ) from e
                        # Parse and cache the source module so `leave_Module` can visit it later.
                        tree = cst.parse_module(source_code)
                        self.model_specific_modules[import_module] = tree
                    imported_object = self.python_module.code_for_node(imported_.name)
                    self.model_specific_imported_objects[imported_object] = import_module
        # Single-name module: forbid global `from transformers import ...` in modular files.
        if m.matches(node.module, m.Name()):
            if import_module == "transformers":
                raise ValueError(
                    f"You are importing from {import_module} directly using global imports. Import from the correct local path"
                )
| |
|
    def visit_SimpleStatementLine(self, node):
        """If we visit an import statement not previously visited, record it. If we visit a module-scope assignment,
        simply record it or, if it is `__all__`, split it between files where we should dispatch it.
        """
        parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node)
        # Matcher for `NAME = ...` at module scope.
        simple_top_level_assign_structure = m.SimpleStatementLine(
            body=[m.Assign(targets=[m.AssignTarget(target=m.Name())])]
        )
        # Matcher for `NAME[...] = ...` or `NAME.attr = ...` at module scope.
        simple_top_level_variable_indexing = m.SimpleStatementLine(
            body=[m.Assign(targets=[m.AssignTarget(target=m.Subscript(value=m.Name()) | m.Attribute(value=m.Name()))])]
        )

        # Only module-level statements are of interest here.
        if m.matches(parent_node, m.Module()):
            if m.matches(node, m.SimpleStatementLine(body=[m.Import()])):
                self.imports.append(node)
            elif m.matches(node, m.SimpleStatementLine(body=[m.ImportFrom()])):
                # Keep only imports that are NOT model-specific (those were consumed by
                # `visit_ImportFrom`), are not excluded external files, and are not in the skip list.
                import_module = (
                    self.python_module.code_for_node(node.body[0].module) if node.body[0].module is not None else ""
                )
                import_statement = "." * len(node.body[0].relative) + import_module
                if any(
                    external_file["name"] in import_statement for external_file in self.excluded_external_files
                ) or not (
                    re.search(rf"(?:transformers\.models\.)|(?:\.\.)\w+\.({self.match_patterns}).*", import_statement)
                    and not any(import_to_skip in import_statement for import_to_skip in IMPORTS_TO_SKIP_IN_MODULAR)
                ):
                    self.imports.append(node)
            elif m.matches(node, simple_top_level_assign_structure):
                assigned_variable = node.body[0].targets[0].target.value
                # `__all__` is split per generated file instead of being recorded as a normal assignment.
                if assigned_variable == "__all__":
                    self.all_all_to_add = split_all_assignment(node, self.model_name)
                else:
                    self.current_assignment = assigned_variable
                    self.assignments[assigned_variable] = node
            # Indexed/attribute writes: keyed by the full statement text, and made a dependency of the
            # variable being indexed so they are emitted together.
            elif m.matches(node, simple_top_level_variable_indexing):
                indexed_variable = node.body[0].targets[0].target.value.value
                self.current_assignment = indexed_variable
                node_name = self.python_module.code_for_node(node)
                self.assignments[node_name] = node
                self.object_dependency_mapping[indexed_variable].add(node_name)
| |
|
    def leave_Module(self, node):
        """When we leave the modular file, we do the following in order:
        1. for each modeling file found in the imports, rename it with the new model name, visit it, and update
        its dependency graph with the new function and assignment definitions found in the modular
        2. update the modular dependency graph with the imported functions and assignments (found when visiting the matching files)
        3. compute the nested (recursive) function and assignment dependencies
        """
        # Finalize the base visitor's own bookkeeping first.
        super().leave_Module(node)

        # 1. Rename, visit and merge each imported modeling file.
        self.visited_modules = {}
        self.renamers = {}
        name_prefixes = self.infer_new_model_name()
        for file, module in self.model_specific_modules.items():
            file_model_name = file.split(".")[-2]
            new_name = name_prefixes[file]
            renamer = ReplaceNameTransformer(file_model_name, new_name, self.model_name)
            renamed_module = module.visit(renamer)
            self.visited_modules[file] = ModelFileMapper.visit_and_merge_dependencies(
                renamed_module,
                self.classes,
                self.functions,
                self.assignments,
                self.object_dependency_mapping,
                self.start_lines,
            )
            # Keep the renamer, it is reused later when dispatching nodes to generated files.
            self.renamers[file] = renamer

        # 2. Merge the imported functions/assignments into the modular's own graph.
        self.merge_model_specific_imports(self.visited_modules)

        # 3. Recursive dependencies on the merged graph.
        self.object_recursive_dependency_mapping = self._compute_recursive_object_dependencies()

        # Record, per file type, the objects that each visited module imported from a modeling file.
        self.imported_objects_per_file = defaultdict(set)
        for file, mapper in self.visited_modules.items():
            file_type = re.search(rf"^transformers\.models\.\w+\.({self.match_patterns})", file).group(1)

            # Excluded external files can carry an explicit type overriding the filename-derived one.
            if self.excluded_external_files:
                for excluded_file in self.excluded_external_files:
                    if file.split(".")[-1] == excluded_file["name"]:
                        file_type = excluded_file["type"]
                        break

            self.imported_objects_per_file[file_type].update(mapper.objects_imported_from_modeling)
|
    def merge_model_specific_imports(self, visited_modules):
        """Merge the functions and assignments imported from the modeling files to the modular nodes and dependency graph,
        based on the visited files."""
        self.start_lines_file_mapping = {}
        self.added_objects_file_mapping = {}
        for object_name, file in self.model_specific_imported_objects.items():
            visited_module = visited_modules[file]
            self.start_lines_file_mapping[file] = visited_module.start_lines
            # Imported functions: pull the node and its direct dependencies into the modular graph.
            if object_name in visited_module.functions and object_name not in self.functions:
                self.functions[object_name] = visited_module.functions[object_name]
                self.added_objects_file_mapping[object_name] = file
                dependencies = visited_module.object_dependency_mapping.get(object_name, None)
                if dependencies is not None:
                    self.object_dependency_mapping[object_name] = dependencies
                    for dep in dependencies:
                        if dep not in self.global_nodes:
                            self.added_objects_file_mapping[dep] = file
                            self.functions[dep] = visited_module.global_nodes[dep]

                # Also propagate the imported function (and its recursive dependencies) to every OTHER
                # visited module, so each of them can resolve the name in its own dependency graph.
                recursive_dependencies = visited_module.object_recursive_dependency_mapping.get(object_name, set())
                node_recursive_dependencies_mapping = {
                    dep: visited_module.global_nodes[dep] for dep in recursive_dependencies
                }
                for filename, module_mapper in self.visited_modules.items():
                    if filename != file:
                        module_mapper.global_nodes[object_name] = visited_module.functions[object_name]
                        if len(recursive_dependencies) > 0:
                            module_mapper.object_recursive_dependency_mapping[object_name] = recursive_dependencies
                            module_mapper.global_nodes.update(node_recursive_dependencies_mapping)

            # Imported assignments: same merging rule as functions (no cross-module propagation here).
            elif object_name in visited_module.assignments and object_name not in self.assignments:
                self.assignments[object_name] = visited_module.assignments[object_name]
                self.added_objects_file_mapping[object_name] = file
                dependencies = visited_module.object_dependency_mapping.get(object_name, None)
                if dependencies is not None:
                    self.object_dependency_mapping[object_name] = dependencies
                    for dep in dependencies:
                        if dep not in self.global_nodes:
                            self.added_objects_file_mapping[dep] = file
                            self.assignments[dep] = visited_module.global_nodes[dep]

        # Rebuild the global node view, then drop dependency edges pointing to unknown names.
        self.global_nodes = {**self.assignments, **self.classes, **self.functions}
        self._restrict_dependencies_to_known_entities()
| |
|
    def compute_relative_order(self, missing_dependencies: set) -> dict[str, int]:
        """Compute in which relative order the `missing_dependencies` should appear when the nodes are added to the final file that
        will be created based on the modular.
        """
        relative_order = {}
        idx = 0

        # Split dependencies between those defined in the modular itself and those pulled in from
        # other (visited modeling) files, grouped per source file.
        original_dependencies = []
        other_files_dependencies = defaultdict(list)
        for dep in sorted(missing_dependencies):
            if dep in self.added_objects_file_mapping:
                file = self.added_objects_file_mapping[dep]
                other_files_dependencies[file].append(dep)
            else:
                original_dependencies.append(dep)
        # Within each group, keep the original in-file ordering (by start line); other-file
        # dependencies come first, then the modular's own definitions.
        all_dependencies = []
        for file, dependencies in other_files_dependencies.items():
            sorted_dependencies = sorted(dependencies, key=lambda x: self.start_lines_file_mapping[file][x])
            all_dependencies += sorted_dependencies
        all_dependencies += sorted(original_dependencies, key=lambda x: self.start_lines[x])

        # Assign consecutive indices in that final order.
        for dep in all_dependencies:
            relative_order[dep] = idx
            idx += 1

        return relative_order
| |
|
    def infer_new_model_name(self) -> dict:
        """Infer whether we are using a model name prefix different from the usual model name as defined from the filename.
        This is useful e.g. when we define a new multi-modal model, and only the text part inherits from `LlamaModel`,
        so we have something like:
        ```python
        class NewModelNameTextDecoderLayer(LlamaDecoderLayer):
            pass
        ```
        with the `Text` prefix added to the model name.
        However, in case of multiple prefix used, we raise a warning and use the most frequent prefix, to avoid parsing
        the same file multiple times and inconsistencies in the objects added from dependencies.
        If the new prefix collides with a prefix of another class in the file where we are importing from, then we also
        raise a warning, and use the default prefix (model name) to avoid collisions in dependencies.

        Returns:
            A dict mapping each inherited-from module filename to the (lowercase) model name to use when
            renaming objects grabbed from that file.
        """
        # For each source file inherited from, count how often each candidate cased prefix appears
        prefix_model_name_mapping = defaultdict(Counter)
        cased_default_name = get_cased_name(self.model_name)
        for class_name, class_node in self.classes.items():
            # Bases that come from another model's file (e.g. `LlamaModel`)
            modeling_bases = [
                k.value.value for k in class_node.bases if k.value.value in self.model_specific_imported_objects
            ]
            if len(modeling_bases) > 1:
                raise ValueError(
                    f"{class_name} was defined with more than 1 model-specific super class. This is unsupported. We found {(*modeling_bases,)}."
                )
            if len(modeling_bases) == 1:
                filename = self.model_specific_imported_objects[modeling_bases[0]]
                cased_model_name = cased_default_name
                # The candidate prefix is the class name minus the suffix it shares with its base
                # (e.g. `NewModelNameTextDecoderLayer` / `LlamaDecoderLayer` -> `NewModelNameText`)
                suffix = common_partial_suffix(class_name, modeling_bases[0])
                if len(suffix) > 0 and suffix[0].isupper():
                    cased_model_name = class_name.replace(suffix, "")
                    # If the derived prefix is shorter than the default name yet the default name appears in the
                    # class name, the shared suffix over-matched -- fall back to the default name
                    if len(cased_model_name) < len(cased_default_name) and cased_default_name in class_name:
                        cased_model_name = cased_default_name
                # No usable common suffix, but the class name is exactly `<default name><base class name>`:
                # build the prefix as default name + cased source-file model name
                elif class_name.replace(cased_default_name, "") == modeling_bases[0]:
                    file_model_name = filename.split(".")[-2]
                    cased_model_name = cased_default_name + get_cased_name(file_model_name)
                prefix_model_name_mapping[filename].update([cased_model_name])

        # Settle on a single prefix per source file, warning on ambiguity or collisions
        final_name_mapping = {}
        for file, prefixes_counter in prefix_model_name_mapping.items():
            if len(prefixes_counter) > 1:
                _, total = prefixes_counter.most_common(1)[0]
                most_used_entities = [name for name, count in prefixes_counter.most_common() if count == total]
                # In case of a tie at the top count, favor the default name if it is among the candidates
                final_name = cased_default_name if cased_default_name in most_used_entities else most_used_entities[-1]
            else:
                final_name = list(prefixes_counter)[0]
            # Translate the chosen prefix back into the source file's own naming, and check whether a class with
            # that prefix already exists there (raw text search works because `class ` starts a line)
            old_cased_model_name = get_cased_name(file.split(".")[-2])
            old_model_name_prefix = final_name.replace(cased_default_name, old_cased_model_name)
            has_prefix_collision = f"\nclass {old_model_name_prefix}" in get_module_source_from_name(file)
            if final_name != cased_default_name and has_prefix_collision:
                if len(prefixes_counter) > 1:
                    logger.warning(
                        f"We detected multiple prefix names when inheriting from {file}: {(*set(prefixes_counter),)}. However, the "
                        f"most used one, '{final_name}', is already present in the source file and will likely cause consistency "
                        f"issues. For this reason we fallback to the default prefix '{cased_default_name}' when grabbing args "
                        "and dependencies. Make sure to subclass the intermediate classes with the prefix you want (if different "
                        f"from '{cased_default_name}') or use a single prefix in all the modular (best)."
                    )
                else:
                    logger.warning(
                        f"We detected the use of the new default prefix {final_name} when inheriting from {file}. However, it is "
                        "already present in the source file and will likely cause consistency issues. For this reason we fallback "
                        f"to the default prefix '{cased_default_name}' when grabbing args and dependencies. Make sure to subclass "
                        f"the intermediate classes with the prefix you want (if different from '{cased_default_name}')"
                    )
                final_name = cased_default_name
            elif len(prefixes_counter) > 1:
                logger.warning(
                    f"We detected multiple prefix names when inheriting from {file}: {(*set(prefixes_counter),)}. We will only "
                    f"use the most used '{final_name}' prefix when grabbing args and dependencies. Make sure to subclass the "
                    f"intermediate classes with the prefix you want (if different from '{final_name}') or use a single prefix "
                    "in all the modular (best)."
                )
            final_name_mapping[file] = get_lowercase_name(final_name)

        # Any visited module without an inferred prefix simply uses the default model name
        for file in self.model_specific_modules:
            if file not in final_name_mapping:
                final_name_mapping[file] = self.model_name

        return final_name_mapping
| |
|
| |
|
def check_dependencies_and_create_import_node(
    file_type: str, new_dependencies: set[str], mapper: ModuleMapper, new_name: str
) -> tuple[set[str], dict[str, cst.CSTNode]]:
    """Check that all class nodes in the `new_dependencies` belong to the correct `file_type`. If this is not the case,
    we need to remove it from the dependencies, and create a new import to it instead.
    This scenario may appear in the following case:
    If a new class in the `modular_xxx.py` file does not belong to `type_xxx.py`, but is used somewhere in `other_type_xxx.py`
    (e.g. as a type hint), but none of the visited files had a similar class, then it would be imported in `type_xxx.py` as
    part of the standard dependency graph (because we never encountered an import towards this new class in any file).
    For example imagine the following `modular.py`:
    ```
    from ..llama.modeling_llama import LlamaModel

    class NewNameTextConfig(PreTrainedConfig):
        ...

    class NewNameConfig(PreTrainedConfig):
        ...

    class NewNameModel(LlamaModel):
        config = NewNameConfig()
        text_config = NewNameTextConfig()
        ...
    ```
    then without the help of this function, `NewNameTextConfig` would be imported in the `modeling_newname.py` as well as
    `configuration_newname.py`, because `modeling_llama.py` tells us to not import `NewNameConfig`, but has no
    knowledge of `NewNameTextConfig`.

    Returns the corrected dependency set and a mapping from each removed class name to its new import statement node.
    """
    corrected_dependencies = set(new_dependencies)
    new_imports = {}
    for dep in new_dependencies:
        # Only class definitions can be misplaced this way; leave every other node kind untouched
        if not m.matches(mapper.global_nodes[dep], m.ClassDef()):
            continue
        dep_file_type = find_file_type(dep, new_name)
        if dep_file_type == file_type:
            continue
        # The class belongs to another generated file: drop it here and import it from there instead
        corrected_dependencies.discard(dep)
        new_imports[dep] = cst.parse_statement(f"from .{dep_file_type}_{new_name} import {dep}")
    return corrected_dependencies, new_imports
| |
|
| |
|
def get_class_node_and_dependencies(
    modular_mapper: ModularFileMapper, class_name: str, node: cst.CSTNode, files: dict[str, dict]
) -> tuple[dict, str, dict]:
    """Return a single class node (and all its dependency nodes), to be added to the `files`. It creates the new
    class node based on the inherited classes if needed. Also returns any new imports of a new class defined in
    the modular that we may need.
    """
    # Bases of this class that come from another model's file (at most one is supported)
    model_specific_bases = [
        k.value.value for k in node.bases if k.value.value in modular_mapper.model_specific_imported_objects
    ]
    super_class = model_specific_bases[0] if len(model_specific_bases) == 1 else None

    # Which generated file (modeling / configuration / ...) this class ends up in
    file_type = find_file_type(class_name, modular_mapper.model_name)
    file_to_update = files[file_type]
    model_name = modular_mapper.model_name

    # Objects already imported in the target file (they never need to be re-added as dependencies)
    imported_objects = modular_mapper.imported_objects_per_file[file_type]

    if super_class is not None:
        # The class inherits from a model-specific class: pull the body and dependencies from that source file
        super_file_name = modular_mapper.model_specific_imported_objects[super_class]

        # Mapper and renamer for the source module the super class lives in
        mapper = modular_mapper.visited_modules[super_file_name]
        renamer = modular_mapper.renamers[super_file_name]
        renamed_super_class = preserve_case_replace(super_class, renamer.patterns, renamer.cased_new_name)

        # Merge the modular class body with the (renamed) super class body
        updated_node = replace_class_node(mapper, node, renamed_super_class, super_class)

        # Dependencies of the merged node, resolved against the source module
        new_node_dependencies = augmented_dependencies_for_class_node(updated_node, mapper, imported_objects)

        # Classes that belong to a different generated file are removed from the dependencies and imported instead
        new_node_dependencies, new_imports = check_dependencies_and_create_import_node(
            file_type, new_node_dependencies, mapper, model_name
        )

        # Classes defined in the modular itself are handled by their own pass -- never pull them from the source file
        new_node_dependencies -= set(modular_mapper.classes.keys())

        # Recursively gather everything the remaining dependencies themselves need, skipping what the target
        # file already contains
        all_dependencies_to_add = find_all_dependencies(
            dependency_mapping=mapper.class_dependency_mapping,
            initial_dependencies=new_node_dependencies,
            initial_checked_dependencies=set(file_to_update.keys()),
        )

        # No `dep not in file_to_update` filter needed here: `find_all_dependencies` already excluded those above
        relative_dependency_order = mapper.compute_relative_order(all_dependencies_to_add)
        nodes_to_add = {
            dep: (relative_dependency_order[dep], mapper.global_nodes[dep]) for dep in all_dependencies_to_add
        }

    # No model-specific super class: the class is used verbatim, with dependencies resolved in the modular itself
    else:
        updated_node = node
        all_dependencies_to_add = augmented_dependencies_for_class_node(updated_node, modular_mapper, imported_objects)

        # Same misplaced-class correction as above, but against the modular mapper
        all_dependencies_to_add, new_imports = check_dependencies_and_create_import_node(
            file_type, all_dependencies_to_add, modular_mapper, model_name
        )

        relative_dependency_order = modular_mapper.compute_relative_order(all_dependencies_to_add)
        nodes_to_add = {
            dep: (relative_dependency_order[dep], modular_mapper.global_nodes[dep])
            for dep in all_dependencies_to_add
            if dep not in file_to_update
        }

    # The class itself always comes last, after all of its dependencies
    class_idx = max(relative_dependency_order.values()) + 1 if len(relative_dependency_order) > 0 else 0
    nodes_to_add[class_name] = (class_idx, updated_node)

    return nodes_to_add, file_type, new_imports
| |
|
| |
|
def create_modules(
    modular_mapper: ModularFileMapper,
    file_path: str | None = None,
    package_name: str | None = "transformers",
) -> dict[str, cst.Module]:
    """Create all the new modules based on visiting the modular file. It replaces all classes as necessary.

    Args:
        modular_mapper: the fully-visited mapper of the modular file.
        file_path: path of the modular file, used to resolve relative imports for external packages.
        package_name: top-level package the files are generated for; anything other than "transformers"
            triggers the external-library import rewriting.

    Returns:
        A dict mapping each file type (e.g. "modeling") to its generated `cst.Module`.
    """
    files = defaultdict(dict)
    current_file_indices = defaultdict(lambda: 0)

    # For each class defined in the modular, compute its final node and all dependency nodes to add
    for class_name, node in modular_mapper.classes.items():
        nodes_to_add, file_type, new_imports = get_class_node_and_dependencies(modular_mapper, class_name, node, files)

        if package_name != "transformers":
            # Outside of `transformers`, the cross-file imports created above must be made absolute
            for key, new_import in new_imports.items():
                new_imports[key] = new_import.with_changes(
                    body=[
                        convert_relative_import_to_absolute(
                            import_node=new_import.body[0], file_path=file_path, package_name=package_name
                        )
                    ]
                )

        # Record the new imports so later classes do not re-add the imported objects as dependencies
        modular_mapper.imported_objects_per_file[file_type].update(new_imports.keys())
        modular_mapper.imports.extend(list(new_imports.values()))

        # Insert the nodes in their relative dependency order
        nodes_to_add = sorted(nodes_to_add.items(), key=lambda x: x[1][0])
        for dependency, (_, node) in nodes_to_add:
            try:
                # Well-known variables are pinned to the very beginning of the generated file
                idx = -1000 + VARIABLES_AT_THE_BEGINNING.index(dependency)
            except ValueError:
                idx = current_file_indices[file_type]
                current_file_indices[file_type] += 1
            files[file_type][dependency] = {"insert_idx": idx, "node": node}

    # `__all__` always goes at the very end of each file
    for file_type, node in modular_mapper.all_all_to_add.items():
        idx = current_file_indices[file_type]
        files[file_type]["__all__"] = {"insert_idx": idx, "node": node}

    # Aggregate the import statements of the modular and of every visited module, deduplicating on their
    # rendered code (node identity is not enough: equal imports can come from different trees)
    all_imports = modular_mapper.imports.copy()
    all_imports_code = {modular_mapper.python_module.code_for_node(node).strip() for node in all_imports}
    for file, mapper in modular_mapper.visited_modules.items():
        new_imports = [
            node for node in mapper.imports if mapper.python_module.code_for_node(node).strip() not in all_imports_code
        ]
        new_imports_code = {mapper.python_module.code_for_node(node).strip() for node in new_imports}
        all_imports.extend(new_imports)
        all_imports_code.update(new_imports_code)

    # Find the subset of imports each file needs and assemble the final module
    for file, body in files.items():
        new_body = [k[1]["node"] for k in sorted(body.items(), key=lambda x: x[1]["insert_idx"])]
        needed_imports = get_needed_imports(body, all_imports)

        if package_name != "transformers":
            # libcst trees are immutable: the previous `imp.body[0] = ...` was item assignment on the node's
            # sequence field and fails at runtime -- rebuild the statement with `with_changes` instead,
            # mirroring the `new_imports` handling above.
            # NOTE(review): `package_name` is hard-coded to "transformers" here, unlike the call above which
            # passes `package_name` through -- confirm this asymmetry is intended.
            for i, imp in enumerate(needed_imports):
                if m.matches(imp, m.SimpleStatementLine(body=[m.ImportFrom()])):
                    needed_imports[i] = imp.with_changes(
                        body=[
                            convert_relative_import_to_absolute(
                                import_node=imp.body[0], file_path=file_path, package_name="transformers"
                            )
                        ]
                    )

        full_module = needed_imports + new_body
        new_module = cst.Module(body=full_module, header=modular_mapper.python_module.header)
        files[file] = new_module

    return files
| |
|
| |
|
def run_ruff(file: str):
    """Run `ruff` linter and formatter on `file`, as in `make style`"""
    # First auto-fix lint violations, then apply the formatter, silencing ruff's stdout in both cases
    for command in (["ruff", "check", file, "--fix"], ["ruff", "format", file]):
        subprocess.run(command, stdout=subprocess.DEVNULL)
| |
|
| |
|
def convert_modular_file(modular_file: str, source_library: str | None = "transformers") -> dict[str, str]:
    """Convert a `modular_file` into all the different model-specific files it depicts.

    Returns a dict mapping each generated file type to its full source code (auto-generated header included).
    Returns an empty dict if `modular_file` does not match the `modular_xxx.py` naming pattern.
    """
    pattern = re.search(r"modular_(.*)(?=\.py$)", modular_file)
    output = {}
    if pattern is not None:
        model_name = pattern.groups()[0]
        # Parse the modular file into a libcst tree
        with open(modular_file, "r", encoding="utf-8") as file:
            code = file.read()
        module = cst.parse_module(code)

        # Path written into the auto-generated header: repo-relative inside `transformers`,
        # absolute (with forward slashes) for external libraries
        if source_library != "transformers":
            relative_path = os.path.abspath(modular_file).replace("\\", "/")
        else:
            relative_path = re.search(
                r"(src/transformers/.*|examples/.*)", os.path.abspath(modular_file).replace("\\", "/")
            )
            if relative_path is None:
                raise ValueError(
                    f"Cannot find the relative path of {modular_file} inside this `transformers` repository. If this modular file is located in another repository and you would like to generate the modeling file there, use the `--external` flag."
                )
            relative_path = relative_path.group(1)

        # For external libraries, rewrite relative imports as absolute ones before visiting the tree
        if source_library != "transformers":
            module = module.visit(AbsoluteImportTransformer(relative_path, source_library))

        wrapper = MetadataWrapper(module)
        cst_transformers = ModularFileMapper(module, model_name, source_library)
        wrapper.visit(cst_transformers)
        # NOTE: `module` is rebound to each generated `cst.Module` from here on
        for file, module in create_modules(
            cst_transformers, file_path=relative_path, package_name=source_library
        ).items():
            if module != {}:
                if source_library != "transformers":
                    # Turn the absolute imports back into relative ones in the generated code
                    module = module.visit(RelativeImportTransformer(relative_path, source_library))

                header = AUTO_GENERATED_MESSAGE.format(
                    relative_path=relative_path, short_name=os.path.basename(relative_path)
                )
                output[file] = header + module.code
        return output
    else:
        print(f"modular pattern not found in {modular_file}, exiting")
        return {}
| |
|
| |
|
def save_modeling_files(modular_file: str, converted_files: dict[str, str]):
    """Save all the `converted_files` from the `modular_file`."""
    for file_type, content in converted_files.items():
        # A file type may contain a `.*` marker: the part before it replaces `modular` in the filename,
        # and the part after it is appended just before the `.py` extension
        parts = file_type.split(".*")
        prefix, suffix = parts[0], parts[-1] if len(parts) > 1 else ""
        target = modular_file.replace("modular_", f"{prefix}_").replace(".py", f"{suffix}.py")
        # Write the generated code next to the modular file
        with open(target, "w", encoding="utf-8") as f:
            f.write(content)
        # Apply `make style`-equivalent formatting on the freshly written file
        run_ruff(target)
| |
|
| |
|
def count_loc(file_path: str) -> int:
    """Count the effective lines of code in `file_path`, ignoring comments and blank lines."""
    with open(file_path, "r", encoding="utf-8") as f:
        source = f.read()
    # Strip `#` comments first so that comment-only lines become blank and are not counted
    without_comments = re.sub(r"#.*", "", source)
    return sum(1 for line in without_comments.split("\n") if line.strip())
| |
|
| |
|
def run_converter(modular_file: str, source_library: str | None = "transformers"):
    """Convert a modular file, and save resulting files."""
    print(f"Converting {modular_file} to a single model single file format")
    converted = convert_modular_file(modular_file, source_library=source_library)
    save_modeling_files(modular_file, converted)

    # Report how many lines of code the modular format saves over the generated files
    directory = os.path.dirname(modular_file)
    modular_loc = count_loc(modular_file)

    # Auto-generated files live next to the modular one and carry the generated-header marker
    generated_files = []
    for name in os.listdir(directory):
        if not name.endswith(".py") or name.startswith("modular_"):
            continue
        candidate = os.path.join(directory, name)
        with open(candidate, "r", encoding="utf-8") as f:
            is_generated = "This file was automatically generated from" in f.read()
        if is_generated:
            generated_files.append(candidate)

    if generated_files:
        total_generated_loc = sum(count_loc(path) for path in generated_files)
        savings = total_generated_loc - modular_loc
        percentage = (savings / total_generated_loc) * 100
        print(
            f"LoC: {modular_loc} (modular) vs {total_generated_loc} (generated) - saved {savings} LoC ({percentage:.1f}%)"
        )
| |
|
| |
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Positional form: `python modular_model_converter.py modular_foo.py ...`
    parser.add_argument(
        "files",
        nargs="*",
        help="A list of `modular_xxxx` files that should be converted to single model file",
    )
    # Flag form kept for backward compatibility; defaults to converting everything
    parser.add_argument(
        "--files-to-parse",
        "--files_to_parse",
        "--files",
        "-f",
        default=["all"],
        nargs="+",
        help="A list of `modular_xxxx` files that should be converted to single model file",
    )
    parser.add_argument(
        "--num_workers",
        "-w",
        default=-1,
        type=int,
        help="The number of workers to use. Default is -1, which means the number of CPU cores.",
    )
    parser.add_argument(
        "--source-library",
        type=str,
        default="transformers",
        help="The top-level package name (default: 'transformers')",
    )
    args = parser.parse_args()
    # Positional arguments take precedence over the --files-to-parse flag
    files_to_parse = args.files if len(args.files) > 0 else args.files_to_parse
    num_workers = mp.cpu_count() if args.num_workers == -1 else args.num_workers

    # Expand the "all"/"examples" shortcuts; otherwise resolve bare model names to their modular file paths.
    # NOTE(review): after "all" expands, the `else` branch below still runs on the expanded list -- harmless,
    # since glob results contain `os.sep` and are therefore skipped by the inner `if`, but confirm intent.
    if files_to_parse == ["all"]:
        files_to_parse = glob.glob("src/transformers/models/**/modular_*.py", recursive=True)
    if files_to_parse == ["examples"]:
        files_to_parse = glob.glob("examples/**/modular_*.py", recursive=True)
    else:
        for i, model_name in enumerate(files_to_parse):
            # A bare model name (no path separator): look it up under `src/transformers/models`
            if os.sep not in model_name:
                full_path = os.path.join("src", "transformers", "models", model_name, f"modular_{model_name}.py")
                # Fallback: the example modular files
                if not os.path.isfile(full_path):
                    full_path = os.path.join("examples", "modular-transformers", f"modular_{model_name}.py")
                if not os.path.isfile(full_path):
                    raise ValueError(f"Cannot find a modular file for {model_name}. Please provide the full path.")
                files_to_parse[i] = full_path

    # Convert files level by level so that dependencies are always generated before their dependents;
    # all levels together must cover every requested file
    ordered_files, _ = find_priority_list(files_to_parse)
    if sum(len(level_files) for level_files in ordered_files) != len(files_to_parse):
        raise ValueError(
            "Some files will not be converted because they do not appear in the dependency graph."
            "This usually means that at least one modular file does not import any model-specific class"
        )

    for dependency_level_files in ordered_files:
        # Files inside a level are independent of each other and can be converted in parallel
        workers = min(num_workers, len(dependency_level_files))
        with mp.Pool(workers) as pool:
            pool.map(partial(run_converter, source_library=args.source_library), dependency_level_files)
|