| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | """ |
| | Utility that ensures that modeling (and modular) files respect some important conventions we have in Transformers. |
| | """ |
| |
|
| | import ast |
| | import sys |
| | from pathlib import Path |
| |
|
| | from rich import print |
| |
|
| |
|
| | MODELS_ROOT = Path("src/transformers/models") |
| | MODELING_PATTERNS = ("modeling_*.py", "modular_*.py") |
| |
|
| |
|
| | def iter_modeling_files(): |
| | for pattern in MODELING_PATTERNS: |
| | yield from MODELS_ROOT.rglob(pattern) |
| |
|
| |
|
| | def colored_error_message(file_path: str, line_number: int, message: str) -> str: |
| | return f"[bold red]{file_path}[/bold red]:[bold yellow]L{line_number}[/bold yellow]: {message}" |
| |
|
| |
|
| | def full_name(node: ast.AST): |
| | """ |
| | Return full dotted name from an Attribute or Name node. |
| | """ |
| | if isinstance(node, ast.Name): |
| | return node.id |
| | elif isinstance(node, ast.Attribute): |
| | return full_name(node.value) + "." + node.attr |
| | else: |
| | raise ValueError("Not a Name or Attribute node") |
| |
|
| |
|
| | def check_init_weights(node: ast.AST, violations: list[str], file_path: str) -> list[str]: |
| | """ |
| | Check that `_init_weights` correctly use `init.(...)` patterns to init the weights in-place. This is very important, |
| | as we rely on the internal flag set on the parameters themselves to check if they need to be re-init or not. |
| | """ |
| | if isinstance(node, ast.FunctionDef) and node.name == "_init_weights": |
| | args = node.args.args |
| | if len(args) < 2 or getattr(args[0], "arg", None) != "self" or getattr(args[1], "arg", None) != "module": |
| | return violations |
| |
|
| | for sub_node in ast.walk(node): |
| | if isinstance(sub_node, ast.Call) and isinstance(sub_node.func, ast.Attribute): |
| | is_inplace_ops = sub_node.func.attr.endswith("_") |
| | |
| | is_on_module_weight = isinstance( |
| | sub_node.func.value, (ast.Name, ast.Attribute) |
| | ) and "module." in full_name(sub_node.func.value) |
| | if is_inplace_ops and is_on_module_weight: |
| | error_msg = ( |
| | "`_init_weights(self, module)` uses an in-place operation on a module's weight. Please use the " |
| | "`init` functions primitives instead, usually imported as `from ... import initialization as init`" |
| | ) |
| | violations.append(colored_error_message(file_path, sub_node.lineno, error_msg)) |
| |
|
| | return violations |
| |
|
| |
|
| | def is_self_method_call(node: ast.AST, method: str) -> bool: |
| | """Check if `node` is a method call on `self`, such as `self.method(...)`""" |
| | return ( |
| | isinstance(node, ast.Call) |
| | and isinstance(node.func, ast.Attribute) |
| | and isinstance(node.func.value, ast.Name) |
| | and node.func.value.id == "self" |
| | and node.func.attr == method |
| | ) |
| |
|
| |
|
| | def is_super_method_call(node: ast.AST, method: str) -> bool: |
| | """Check if `node` is a call to `super().method(...)`""" |
| | return ( |
| | isinstance(node, ast.Call) |
| | and isinstance(node.func, ast.Attribute) |
| | and isinstance(node.func.value, ast.Call) |
| | and isinstance(node.func.value.func, ast.Name) |
| | and node.func.value.func.id == "super" |
| | and node.func.attr == method |
| | ) |
| |
|
| |
|
| | def check_post_init(node: ast.AST, violations: list[str], file_path: str) -> list[str]: |
| | """ |
| | Check that `self.post_init()` is correctly called at the end of `__init__` for all `PreTrainedModel`s. This is |
| | very important as we need to do some processing there. |
| | """ |
| | |
| | if isinstance(node, ast.ClassDef) and any(full_name(parent).endswith("PreTrainedModel") for parent in node.bases): |
| | for sub_node in node.body: |
| | |
| | if isinstance(sub_node, ast.FunctionDef) and sub_node.name == "__init__": |
| | for statement in ast.walk(sub_node): |
| | |
| | if is_self_method_call(statement, method="post_init"): |
| | break |
| | |
| | elif "modular_" in str(file_path) and is_super_method_call(statement, method="__init__"): |
| | break |
| | |
| | else: |
| | error_msg = f"`__init__` of {node.name} does not call `self.post_init`" |
| | violations.append(colored_error_message(file_path, sub_node.lineno, error_msg)) |
| | break |
| |
|
| | return violations |
| |
|
| |
|
| | def main(): |
| | violations: list[str] = [] |
| |
|
| | for file_path in iter_modeling_files(): |
| | try: |
| | text = file_path.read_text(encoding="utf-8") |
| | tree = ast.parse(text, filename=str(file_path)) |
| | except Exception as exc: |
| | violations.append(f"{file_path}: failed to parse ({exc}).") |
| | continue |
| |
|
| | for node in ast.walk(tree): |
| | violations = check_init_weights(node, violations, file_path) |
| | violations = check_post_init(node, violations, file_path) |
| |
|
| | if len(violations) > 0: |
| | violations = sorted(violations) |
| | print("\n".join(violations), file=sys.stderr) |
| | raise ValueError("Some errors in modelings. Check the above message") |
| |
|
| |
|
| | if __name__ == "__main__": |
| | main() |
| |
|