File size: 6,103 Bytes
a9bd396
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utility that ensures that modeling (and modular) files respect some important conventions we have in Transformers.
"""

import ast
import sys
from pathlib import Path

from rich import print


# Directory searched for model files. The path is relative, so it is resolved against the current working directory.
MODELS_ROOT = Path("src/transformers/models")
# Glob patterns matched recursively under `MODELS_ROOT` to select the files to check.
MODELING_PATTERNS = ("modeling_*.py", "modular_*.py")


def iter_modeling_files():
    """
    Yield every path under `MODELS_ROOT` (searched recursively) that matches one of `MODELING_PATTERNS`.
    """
    for glob_pattern in MODELING_PATTERNS:
        for matched_path in MODELS_ROOT.rglob(glob_pattern):
            yield matched_path


def colored_error_message(file_path: str, line_number: int, message: str) -> str:
    """
    Format `message` prefixed by its location, with `[bold red]`/`[bold yellow]` markup tags around the file path
    and the line number.
    """
    location = f"[bold red]{file_path}[/bold red]:[bold yellow]L{line_number}[/bold yellow]"
    return f"{location}: {message}"


def full_name(node: ast.AST):
    """
    Build the dotted name (e.g. `a.b.c`) described by a chain of Attribute nodes ending in a Name node.

    Raises a `ValueError` if `node`, or the innermost value of the chain, is neither a Name nor an Attribute.
    """
    # Attributes are visited from the outermost (rightmost) one inwards, so they are collected in reverse order
    reversed_parts = []
    current = node
    while isinstance(current, ast.Attribute):
        reversed_parts.append(current.attr)
        current = current.value

    if not isinstance(current, ast.Name):
        raise ValueError("Not a Name or Attribute node")

    reversed_parts.append(current.id)
    return ".".join(reversed(reversed_parts))


def _is_module_attribute_target(target: ast.AST) -> bool:
    """
    Return True if `target` is a plain dotted name (Name/Attribute chain) whose dotted form contains `module.`.

    Receivers that cannot be resolved to a dotted name (e.g. `module.get_w().data` or `module.weight[0].data`, whose
    chain goes through a Call or a Subscript) are treated as not being a module weight instead of raising.
    """
    if not isinstance(target, (ast.Name, ast.Attribute)):
        return False
    try:
        dotted_name = full_name(target)
    except ValueError:
        # The attribute chain bottoms out in something else than a Name: there is no dotted name to inspect
        return False
    return "module." in dotted_name


def check_init_weights(node: ast.AST, violations: list[str], file_path: str) -> list[str]:
    """
    Check that `_init_weights` correctly uses `init.(...)` patterns to init the weights in-place. This is very
    important, as we rely on the internal flag set on the parameters themselves to check if they need to be re-init
    or not.

    Only functions named `_init_weights` whose first two arguments are `self` and `module` are inspected. Every call
    to a method whose name ends with `_` on a dotted name containing `module.` is reported.

    Args:
        node: The AST node to inspect. Anything that is not a matching `_init_weights` definition is ignored.
        violations: The list of error messages found so far. It is extended in-place.
        file_path: The path of the file `node` comes from, only used in the error messages.

    Returns:
        The same `violations` list, with one message appended per offending call.
    """
    if not (isinstance(node, ast.FunctionDef) and node.name == "_init_weights"):
        return violations

    args = node.args.args
    if len(args) < 2 or getattr(args[0], "arg", None) != "self" or getattr(args[1], "arg", None) != "module":
        return violations

    for sub_node in ast.walk(node):
        if not (isinstance(sub_node, ast.Call) and isinstance(sub_node.func, ast.Attribute)):
            continue
        # By convention, in-place ops are the methods whose name ends with an underscore
        if not sub_node.func.attr.endswith("_"):
            continue
        # We allow in-place ops on tensors that are not part of the module itself (see e.g. modeling_qwen3_next.py L997)
        if _is_module_attribute_target(sub_node.func.value):
            error_msg = (
                "`_init_weights(self, module)` uses an in-place operation on a module's weight. Please use the "
                "`init` functions primitives instead, usually imported as `from ... import initialization as init`"
            )
            violations.append(colored_error_message(file_path, sub_node.lineno, error_msg))

    return violations


def is_self_method_call(node: ast.AST, method: str) -> bool:
    """Tell whether `node` is a call of the form `self.<method>(...)`"""
    if not isinstance(node, ast.Call):
        return False

    callee = node.func
    if not isinstance(callee, ast.Attribute):
        return False

    # The receiver must be the bare name `self`, not a longer chain such as `self.a`
    receiver = callee.value
    if not isinstance(receiver, ast.Name):
        return False

    return receiver.id == "self" and callee.attr == method


def is_super_method_call(node: ast.AST, method: str) -> bool:
    """Tell whether `node` is a call of the form `super(...).<method>(...)`"""
    if not isinstance(node, ast.Call):
        return False

    callee = node.func
    if not isinstance(callee, ast.Attribute):
        return False

    # The receiver must itself be a call to the bare name `super`
    receiver = callee.value
    if not isinstance(receiver, ast.Call):
        return False
    if not isinstance(receiver.func, ast.Name):
        return False

    return receiver.func.id == "super" and callee.attr == method


def _inherits_from_pretrained_model(class_node: ast.ClassDef) -> bool:
    """
    Return True if one of the bases of `class_node` is a dotted name ending with `PreTrainedModel`.

    Bases that are not plain dotted names (e.g. `Generic[T]` or `make_base()`) cannot be resolved to a name, so they
    are skipped instead of raising.
    """
    for base in class_node.bases:
        if not isinstance(base, (ast.Name, ast.Attribute)):
            continue
        try:
            base_name = full_name(base)
        except ValueError:
            # Attribute chain rooted in a Call/Subscript/...: no dotted name to compare against
            continue
        if base_name.endswith("PreTrainedModel"):
            return True
    return False


def check_post_init(node: ast.AST, violations: list[str], file_path: str) -> list[str]:
    """
    Check that `self.post_init()` is correctly called in `__init__` for all `PreTrainedModel`s. This is very
    important as we need to do some processing there.

    Only the first `__init__` found directly in the class body is inspected. In files whose path contains
    `modular_`, a call to `super().__init__(...)` is accepted as well, as `post_init` is then called by the parent.

    Args:
        node: The AST node to inspect. Anything that is not a class with a `...PreTrainedModel` base is ignored.
        violations: The list of error messages found so far. It is extended in-place.
        file_path: The path of the file `node` comes from, used in the error messages and for the modular detection.

    Returns:
        The same `violations` list, with one message appended if `post_init` is never called.
    """
    # Check if it's a PreTrainedModel class definition
    if not (isinstance(node, ast.ClassDef) and _inherits_from_pretrained_model(node)):
        return violations

    for sub_node in node.body:
        # Check that we are in __init__
        if isinstance(sub_node, ast.FunctionDef) and sub_node.name == "__init__":
            for statement in ast.walk(sub_node):
                # This means it's correctly called verbatim
                if is_self_method_call(statement, method="post_init"):
                    break
                # This means `super().__init__` is called in a modular, so it is already called in the parent
                elif "modular_" in str(file_path) and is_super_method_call(statement, method="__init__"):
                    break
            # If we did not break, `post_init` was never called
            else:
                error_msg = f"`__init__` of {node.name} does not call `self.post_init`"
                violations.append(colored_error_message(file_path, sub_node.lineno, error_msg))
            # Only the first `__init__` of the class body is considered
            break

    return violations


def main():
    """
    Run all the checks on every file yielded by `iter_modeling_files`.

    Files that cannot be read or parsed are reported as violations. If anything was reported, the sorted messages
    are printed to stderr and a `ValueError` is raised.
    """
    problems: list[str] = []

    for modeling_file in iter_modeling_files():
        try:
            source = modeling_file.read_text(encoding="utf-8")
            module_ast = ast.parse(source, filename=str(modeling_file))
        except Exception as exc:
            problems.append(f"{modeling_file}: failed to parse ({exc}).")
        else:
            # The checks run outside of the `try`, so that their own errors are not reported as parsing failures
            for node in ast.walk(module_ast):
                problems = check_init_weights(node, problems, modeling_file)
                problems = check_post_init(node, problems, modeling_file)

    if not problems:
        return

    problems.sort()
    print("\n".join(problems), file=sys.stderr)
    raise ValueError("Some errors in modelings. Check the above message")


# Only run the checks when this file is executed as a script, not when it is imported
if __name__ == "__main__":
    main()