#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utility that ensures that modeling (and modular) files respect some important conventions we have in Transformers.
"""
import ast
import sys
from pathlib import Path
from rich import print
# Root directory that is scanned for files to check. The path is relative, so it is resolved against the current
# working directory of the process running this script.
MODELS_ROOT = Path("src/transformers/models")
# Glob patterns, searched recursively under `MODELS_ROOT`, that select the files to check.
MODELING_PATTERNS = ("modeling_*.py", "modular_*.py")
def iter_modeling_files():
    """
    Yield the path of every file under `MODELS_ROOT` matching one of the `MODELING_PATTERNS`.

    The tree is searched recursively, once per pattern, in the order the patterns are listed.
    """
    for glob_pattern in MODELING_PATTERNS:
        for candidate in MODELS_ROOT.rglob(glob_pattern):
            yield candidate
def colored_error_message(file_path: str, line_number: int, message: str) -> str:
    """
    Format a violation as `<file>:L<line>: <message>`, with the file and the line wrapped in `[bold ...]` markup tags.
    """
    file_tag = f"[bold red]{file_path}[/bold red]"
    line_tag = f"[bold yellow]L{line_number}[/bold yellow]"
    return f"{file_tag}:{line_tag}: {message}"
def full_name(node: ast.AST):
    """
    Build the dotted name (e.g. `a.b.c`) spelled by a Name node or by a chain of Attribute nodes rooted at a Name.

    Raises a `ValueError` if any link of the chain is neither a Name nor an Attribute node.
    """
    # Attribute chains are nested from the outside in, so the attribute names are collected last-to-first
    trailing_attrs = []
    current = node
    while isinstance(current, ast.Attribute):
        trailing_attrs.append(current.attr)
        current = current.value
    if not isinstance(current, ast.Name):
        raise ValueError("Not a Name or Attribute node")
    trailing_attrs.append(current.id)
    return ".".join(reversed(trailing_attrs))
def check_init_weights(node: ast.AST, violations: list[str], file_path: str) -> list[str]:
    """
    Check that `_init_weights` correctly uses the `init.(...)` patterns to init the weights in-place. This is very
    important, as we rely on the internal flag set on the parameters themselves to check if they need to be re-init
    or not.

    Args:
        node: Any AST node. Only function definitions named `_init_weights` whose first two positional parameters
            are `self` and `module` are inspected; every other node is ignored.
        violations: List of already collected messages. It is extended in-place.
        file_path: Path of the file `node` comes from, only used to build the messages.

    Returns:
        The same `violations` list that was passed in.
    """
    if not (isinstance(node, ast.FunctionDef) and node.name == "_init_weights"):
        return violations

    args = node.args.args
    if len(args) < 2 or getattr(args[0], "arg", None) != "self" or getattr(args[1], "arg", None) != "module":
        return violations

    for sub_node in ast.walk(node):
        if not (isinstance(sub_node, ast.Call) and isinstance(sub_node.func, ast.Attribute)):
            continue
        # Only in-place ops (trailing underscore) are of interest, so do not resolve any name for other calls
        if not sub_node.func.attr.endswith("_"):
            continue
        target = sub_node.func.value
        if not isinstance(target, (ast.Name, ast.Attribute)):
            continue
        try:
            target_name = full_name(target)
        except ValueError:
            # The outermost node is an Attribute, but a call or a subscript sits deeper in the chain (e.g.
            # `module.layers[0].weight.zero_()`). Such a target has no plain dotted name: skip it instead of
            # letting the exception abort the whole check.
            continue
        # We allow in-place ops on tensors that are not part of the module itself (see e.g. modeling_qwen3_next.py L997)
        if "module." in target_name:
            error_msg = (
                "`_init_weights(self, module)` uses an in-place operation on a module's weight. Please use the "
                "`init` functions primitives instead, usually imported as `from ... import initialization as init`"
            )
            violations.append(colored_error_message(file_path, sub_node.lineno, error_msg))

    return violations
def is_self_method_call(node: ast.AST, method: str) -> bool:
    """Tell whether `node` is a call of the form `self.<method>(...)`"""
    if not isinstance(node, ast.Call):
        return False
    callee = node.func
    if not isinstance(callee, ast.Attribute) or callee.attr != method:
        return False
    # The receiver must be the bare name `self`, not a longer chain such as `self.a`
    receiver = callee.value
    return isinstance(receiver, ast.Name) and receiver.id == "self"
def is_super_method_call(node: ast.AST, method: str) -> bool:
    """Tell whether `node` is a call of the form `super(...).<method>(...)`"""
    if not isinstance(node, ast.Call):
        return False
    callee = node.func
    if not isinstance(callee, ast.Attribute) or callee.attr != method:
        return False
    # The receiver must itself be a call, and the thing being called must be the bare name `super`
    receiver = callee.value
    if not isinstance(receiver, ast.Call):
        return False
    return isinstance(receiver.func, ast.Name) and receiver.func.id == "super"
def check_post_init(node: ast.AST, violations: list[str], file_path: str) -> list[str]:
    """
    Check that `self.post_init()` is called in the `__init__` of all `PreTrainedModel`s. This is very important as
    we need to do some processing there.

    A class is considered a `PreTrainedModel` when the dotted name of one of its bases ends with `PreTrainedModel`.
    Only the first `__init__` defined directly in the class body is inspected. In files whose path contains
    `modular_`, a call to `super().__init__(...)` is accepted as well, as `post_init` is then already called in the
    parent.

    Args:
        node: Any AST node. Everything that is not such a class definition is ignored.
        violations: List of already collected messages. It is extended in-place.
        file_path: Path of the file `node` comes from, used for the `modular_` test and to build the messages.

    Returns:
        The same `violations` list that was passed in.
    """
    if not isinstance(node, ast.ClassDef):
        return violations
    if not any(_is_pretrained_model_base(parent) for parent in node.bases):
        return violations

    init_def = next(
        (item for item in node.body if isinstance(item, ast.FunctionDef) and item.name == "__init__"),
        None,
    )
    # Nothing to check if the class does not define its own `__init__`
    if init_def is None:
        return violations

    is_modular = "modular_" in str(file_path)
    calls_post_init = any(
        # This means it's correctly called verbatim
        is_self_method_call(statement, method="post_init")
        # This means `super().__init__` is called in a modular, so it is already called in the parent
        or (is_modular and is_super_method_call(statement, method="__init__"))
        for statement in ast.walk(init_def)
    )
    if not calls_post_init:
        error_msg = f"`__init__` of {node.name} does not call `self.post_init`"
        violations.append(colored_error_message(file_path, init_def.lineno, error_msg))

    return violations


def _is_pretrained_model_base(base: ast.AST) -> bool:
    """Tell whether a class base is a plain dotted name ending with `PreTrainedModel`."""
    try:
        return full_name(base).endswith("PreTrainedModel")
    except ValueError:
        # Bases such as `Generic[T]` or `make_base()` have no dotted name. They cannot match, and must not make
        # the whole check crash.
        return False
def main():
    """
    Run every check on every modeling/modular file, and raise a `ValueError` if anything was reported.

    A file that cannot be read or parsed is reported as a violation as well, and the remaining files are still
    checked. All the collected messages are sorted and printed to stderr before raising.
    """
    violations: list[str] = []
    checks = (check_init_weights, check_post_init)

    for file_path in iter_modeling_files():
        try:
            source = file_path.read_text(encoding="utf-8")
            tree = ast.parse(source, filename=str(file_path))
        except Exception as exc:
            violations.append(f"{file_path}: failed to parse ({exc}).")
            continue

        # Every node goes through every check, each check decides by itself whether the node is relevant to it
        for node in ast.walk(tree):
            for check in checks:
                violations = check(node, violations, file_path)

    if not violations:
        return

    violations.sort()
    print("\n".join(violations), file=sys.stderr)
    raise ValueError("Some errors in modelings. Check the above message")
# Run the checks only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
|