Spaces:
Running on Zero
Running on Zero
| # coding=utf-8 | |
| # Copyright 2026 The HuggingFace Inc. team. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """ | |
| Check that arguments of ``forward()`` (for models) and ``__call__()`` (for | |
| pipelines) match the method's docstring exactly: | |
| * every signature argument has an entry in the ``Args:`` / | |
| ``Arguments:`` / ``Parameters:`` section, | |
| * every documented argument still exists in the signature | |
| (stale entries from removed/renamed args are flagged), and | |
| * when the method has a non-``None`` return annotation, the docstring has | |
| a ``Returns:`` / ``Return:`` / ``Yields:`` section. | |
| A "main" class is detected via its base classes — models inherit from | |
| ``ModelMixin`` and pipelines inherit from ``DiffusionPipeline``. Only methods | |
| defined directly on the class are checked; inherited methods are checked when | |
| the parent class is visited. | |
| Run from the repository root: | |
| python utils/check_forward_call_docstrings.py | |
| Optionally restrict to specific files: | |
| python utils/check_forward_call_docstrings.py --paths src/diffusers/models/transformers/transformer_flux.py | |
| Auto-fix stale (documented-but-removed) entries — missing entries are never | |
| auto-added (no placeholders), only stale ones are removed: | |
| python utils/check_forward_call_docstrings.py --fix | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import ast | |
| import re | |
| import sys | |
| from pathlib import Path | |
| REPO_ROOT = Path(__file__).resolve().parents[1] | |
| MODELS_DIR = REPO_ROOT / "src" / "diffusers" / "models" | |
| PIPELINES_DIR = REPO_ROOT / "src" / "diffusers" / "pipelines" | |
| MODEL_BASE = "ModelMixin" | |
| PIPELINE_BASE = "DiffusionPipeline" | |
| SECTION_HEADERS = { | |
| "Args:", | |
| "Arguments:", | |
| "Parameters:", | |
| "Returns:", | |
| "Return:", | |
| "Yields:", | |
| "Raises:", | |
| "Examples:", | |
| "Example:", | |
| "Note:", | |
| "Notes:", | |
| "References:", | |
| "See Also:", | |
| } | |
| # `name (...)` or `name:` at the start of a (stripped) line. | |
| _ARG_HEADER_RE = re.compile(r"^([A-Za-z_]\w*)\s*[(:]") | |
| # Pairs of (class_name, method_name) whose missing-arg errors should be | |
| # suppressed. Use sparingly — prefer fixing the docstring. | |
| IGNORE: set[tuple[str, str]] = set() | |
| def _base_class_names(class_def: ast.ClassDef) -> set[str]: | |
| """Return the textual names of base classes (best-effort).""" | |
| names: set[str] = set() | |
| for base in class_def.bases: | |
| if isinstance(base, ast.Name): | |
| names.add(base.id) | |
| elif isinstance(base, ast.Attribute): | |
| names.add(base.attr) | |
| return names | |
| def _find_method(class_def: ast.ClassDef, method_name: str) -> ast.FunctionDef | ast.AsyncFunctionDef | None: | |
| for node in class_def.body: | |
| if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and node.name == method_name: | |
| return node | |
| return None | |
| def _docstring_node(func: ast.FunctionDef | ast.AsyncFunctionDef) -> ast.Expr | None: | |
| if ( | |
| func.body | |
| and isinstance(func.body[0], ast.Expr) | |
| and isinstance(func.body[0].value, ast.Constant) | |
| and isinstance(func.body[0].value.value, str) | |
| ): | |
| return func.body[0] | |
| return None | |
| def _signature_arg_names(func: ast.FunctionDef | ast.AsyncFunctionDef) -> list[str]: | |
| args = func.args | |
| collected: list[str] = [] | |
| for a in (*args.posonlyargs, *args.args, *args.kwonlyargs): | |
| if a.arg == "self" or a.arg == "cls": | |
| continue | |
| collected.append(a.arg) | |
| return collected | |
| def _has_meaningful_return(func: ast.FunctionDef | ast.AsyncFunctionDef) -> bool: | |
| """True iff the method has a return annotation other than ``None`` or ``NoReturn``.""" | |
| ret = func.returns | |
| if ret is None: # no annotation at all | |
| return False | |
| if isinstance(ret, ast.Constant) and ret.value is None: # `-> None` | |
| return False | |
| # `-> NoReturn` or `-> typing.NoReturn` | |
| if isinstance(ret, ast.Name) and ret.id == "NoReturn": | |
| return False | |
| if isinstance(ret, ast.Attribute) and ret.attr == "NoReturn": | |
| return False | |
| return True | |
| def _has_returns_section(docstring: str | None) -> bool: | |
| if not docstring: | |
| return False | |
| for line in docstring.splitlines(): | |
| if line.strip() in {"Returns:", "Return:", "Yields:", "Yield:"}: | |
| return True | |
| return False | |
def _extract_documented_args(docstring: str | None) -> set[str]:
    """Extract argument names listed in an Args/Arguments/Parameters section.

    Assumes the docstring has been cleaned (``inspect.cleandoc`` /
    ``ast.get_docstring``), so indentation is relative to the docstring.

    * Only the first ``Args:`` / ``Arguments:`` / ``Parameters:`` header is used;
      without one, the result is empty.
    * The section ends at the first line that is one of ``SECTION_HEADERS`` and
      is indented no deeper than that header, or at the end of the docstring.
      No blank line is required before the terminating header.
    * The first non-blank line after the header fixes the entry indent. Only
      lines at exactly that indent are candidate entry headers. A candidate
      counts when it matches ``_ARG_HEADER_RE`` (``name (...)`` or ``name:``).
      Deeper lines are descriptions/continuations and are ignored.
    """
    if not docstring:
        return set()
    lines = docstring.splitlines()

    # Locate the (first) Args/Arguments/Parameters header.
    start = None
    header_indent = 0
    for i, line in enumerate(lines):
        stripped = line.strip()
        if stripped in {"Args:", "Arguments:", "Parameters:"}:
            start = i + 1
            header_indent = len(line) - len(line.lstrip())
            break
    if start is None:
        return set()

    # First non-empty line after the header sets the per-entry indent level.
    entry_indent: int | None = None
    documented: set[str] = set()
    for line in lines[start:]:
        stripped = line.strip()
        if not stripped:
            continue
        indent = len(line) - len(line.lstrip())
        # A new section at the same (or shallower) indent ends the args block.
        if indent <= header_indent and stripped in SECTION_HEADERS:
            break
        if entry_indent is None:
            entry_indent = indent
        # Only lines at the entry indent are candidate arg headers; deeper
        # indents are descriptions/continuations.
        if indent != entry_indent:
            continue
        match = _ARG_HEADER_RE.match(stripped)
        if match:
            documented.add(match.group(1))
    return documented
def check_file(path: Path, kind: str) -> list[str]:
    """Return a list of human-readable error strings for ``path``.

    ``kind == "model"`` checks ``forward`` on classes whose bases include
    ``MODEL_BASE``. Any other value checks ``__call__`` on classes whose bases
    include ``PIPELINE_BASE``. Pairs listed in ``IGNORE`` are skipped. For every
    other matching class, the following are reported:

    * signature arguments without a docstring entry ("missing");
    * documented arguments absent from the signature ("stale"). The names of
      ``*args`` / ``**kwargs`` count as part of the signature here, so
      documenting them is allowed but never required;
    * a meaningful return annotation (see ``_has_meaningful_return``) without
      a ``Returns:`` / ``Return:`` / ``Yields:`` / ``Yield:`` section.

    Files that cannot be decoded as UTF-8 or parsed yield no errors.
    """
    method_name = "forward" if kind == "model" else "__call__"
    base_class = MODEL_BASE if kind == "model" else PIPELINE_BASE
    try:
        tree = ast.parse(path.read_text(encoding="utf-8"))
    except (SyntaxError, UnicodeDecodeError):
        return []
    try:
        rel = path.relative_to(REPO_ROOT)
    except ValueError:
        # `--paths` may point outside the repository; report such paths as given.
        rel = path

    errors: list[str] = []
    for node in ast.walk(tree):
        if not isinstance(node, ast.ClassDef):
            continue
        if base_class not in _base_class_names(node):
            continue
        if (node.name, method_name) in IGNORE:
            continue
        method = _find_method(node, method_name)
        if method is None:
            continue

        sig_args = _signature_arg_names(method)
        # `*args` / `**kwargs` are not required entries, but a docstring that
        # documents them (e.g. "kwargs (`dict`, *optional*):") is not stale.
        variadic = {a.arg for a in (method.args.vararg, method.args.kwarg) if a is not None}
        allowed = set(sig_args) | variadic
        docstring_text = ast.get_docstring(method)
        documented = _extract_documented_args(docstring_text)
        missing = [a for a in sig_args if a not in documented]
        stale = sorted(documented - allowed)

        if missing:
            errors.append(
                f"{rel}:{method.lineno}: {node.name}.{method_name} is missing "
                f"docstring entries for: {', '.join(missing)}"
            )
        if stale:
            errors.append(
                f"{rel}:{method.lineno}: {node.name}.{method_name} documents "
                f"argument(s) not in the signature: {', '.join(stale)}"
            )
        if _has_meaningful_return(method) and not _has_returns_section(docstring_text):
            return_repr = ast.unparse(method.returns)
            ds = _docstring_node(method)
            if ds is None:
                where = " (method has no docstring)"
            else:
                where = f' (add it just above the closing """ on line {ds.end_lineno})'
            errors.append(
                f"{rel}:{method.lineno}: {node.name}.{method_name} returns "
                f"`{return_repr}` but the docstring has no Returns: section{where}"
            )
    return errors
def fix_file(path: Path, kind: str) -> list[str]:
    """Remove stale arg entries (documented but not in signature) in-place.

    Missing-in-signature → docstring entries are NOT added (no placeholders).
    Documented ``*args`` / ``**kwargs`` names are never treated as stale.
    Classes listed in ``IGNORE`` are left untouched; ``check_file`` skips them
    as well. Files that cannot be decoded as UTF-8 or parsed are skipped.

    Each removed entry spans from its header line up to the next entry header,
    the next section header, or the closing quotes. Trailing blank lines of the
    last entry are kept. The file is rewritten only if something was removed.

    Returns a list of ``"ClassName.method: removed name1, name2"`` strings
    describing what was removed.
    """
    method_name = "forward" if kind == "model" else "__call__"
    base_class = MODEL_BASE if kind == "model" else PIPELINE_BASE
    try:
        source = path.read_text(encoding="utf-8")
        tree = ast.parse(source)
    except (SyntaxError, UnicodeDecodeError):
        return []

    # `ast` line numbers count "\n"-terminated lines only. `str.splitlines`
    # also breaks on "\f", "\v", "\x1c"-"\x1e", "\x85", "\u2028" and "\u2029",
    # which would shift every index after such a character and delete the
    # wrong lines. `read_text` has already translated "\r\n" / "\r" to "\n".
    lines = source.split("\n")
    # (start_idx, end_idx_exclusive) ranges of lines to drop.
    deletions: list[tuple[int, int]] = []
    summaries: list[str] = []

    for node in ast.walk(tree):
        if not isinstance(node, ast.ClassDef):
            continue
        if base_class not in _base_class_names(node):
            continue
        if (node.name, method_name) in IGNORE:
            continue
        method = _find_method(node, method_name)
        if method is None:
            continue
        # Method must start with a string docstring expression.
        docstring_expr = _docstring_node(method)
        if docstring_expr is None:
            continue

        variadic = {a.arg for a in (method.args.vararg, method.args.kwarg) if a is not None}
        allowed = set(_signature_arg_names(method)) | variadic
        stale = _extract_documented_args(ast.get_docstring(method)) - allowed
        if not stale:
            continue

        doc_start = docstring_expr.lineno - 1  # 0-indexed
        doc_end = docstring_expr.end_lineno - 1  # 0-indexed, inclusive

        # Locate the Args/Arguments/Parameters header in raw source.
        args_idx: int | None = None
        header_indent = 0
        for i in range(doc_start, doc_end + 1):
            stripped = lines[i].strip()
            if stripped in {"Args:", "Arguments:", "Parameters:"}:
                args_idx = i
                header_indent = len(lines[i]) - len(lines[i].lstrip())
                break
        if args_idx is None:
            continue

        # First non-empty line after the header sets the per-entry indent.
        entry_indent: int | None = None
        for i in range(args_idx + 1, doc_end + 1):
            stripped = lines[i].strip()
            if not stripped:
                continue
            entry_indent = len(lines[i]) - len(lines[i].lstrip())
            break
        # Refuse to edit sections whose entries are not indented under the header.
        if entry_indent is None or entry_indent <= header_indent:
            continue

        # Walk entries; each entry spans from its header line up to (but not
        # including) the next entry header / section header / end of docstring.
        current_name: str | None = None
        current_start: int = -1
        end_of_args: int | None = None
        for i in range(args_idx + 1, doc_end + 1):
            line = lines[i]
            stripped = line.strip()
            if not stripped:
                continue
            indent = len(line) - len(line.lstrip())
            if indent <= header_indent and stripped in SECTION_HEADERS:
                end_of_args = i
                break
            if indent == entry_indent:
                m = _ARG_HEADER_RE.match(stripped)
                if m:
                    if current_name in stale:
                        deletions.append((current_start, i))
                    current_name = m.group(1)
                    current_start = i
        if current_name in stale:
            end = end_of_args if end_of_args is not None else doc_end
            # Trailing blank lines belong to inter-section spacing (or the
            # blank line before the closing """), not to this entry.
            while end > current_start + 1 and not lines[end - 1].strip():
                end -= 1
            deletions.append((current_start, end))

        summaries.append(f"{node.name}.{method_name}: removed {', '.join(sorted(stale))}")

    if not deletions:
        return []
    deletions.sort()
    new_lines = list(lines)
    # Delete bottom-up so earlier indices stay valid.
    for start, end in reversed(deletions):
        del new_lines[start:end]
    path.write_text("\n".join(new_lines), encoding="utf-8")
    return summaries
def _kind_for_path(path: Path) -> str | None:
    """Classify ``path`` as ``"pipeline"``, ``"model"`` or ``None``.

    Files inside this repository's ``PIPELINES_DIR`` / ``MODELS_DIR`` are
    classified by the directory that contains them. Any other path falls back
    to looking for a ``pipelines`` or ``models`` component anywhere in its
    resolved path; ``pipelines`` wins when both appear.
    """
    resolved = path.resolve()
    # Check the repo's own directories first. The fallback inspects the whole
    # absolute path, so a checkout living under e.g. `~/pipelines/` would
    # otherwise classify every model file as a pipeline.
    if resolved.is_relative_to(PIPELINES_DIR):
        return "pipeline"
    if resolved.is_relative_to(MODELS_DIR):
        return "model"
    parts = resolved.parts
    if "pipelines" in parts:
        return "pipeline"
    if "models" in parts:
        return "model"
    return None
def main() -> int:
    """Parse the command line, optionally fix stale entries, then check and report.

    Returns 1 when any mismatch remains after the optional ``--fix`` pass and
    0 otherwise. Invalid usage (including a ``--paths`` entry that does not
    exist) exits with status 2 via ``argparse``.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        # The module docstring contains hand-formatted usage examples; the
        # default formatter would re-wrap them into a single paragraph.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--paths",
        nargs="+",
        help=(
            "Specific files (or directories, searched recursively for *.py) to check "
            "(defaults to all of src/diffusers/{models,pipelines})."
        ),
    )
    parser.add_argument(
        "--limit",
        type=int,
        default=None,
        help=(
            "Debug helper: when --paths is not given, only check the first N files "
            "(in sorted order) from each of models/ and pipelines/."
        ),
    )
    parser.add_argument(
        "--fix",
        action="store_true",
        help=(
            "Remove stale (documented-but-not-in-signature) argument entries from "
            "docstrings in-place. Missing-in-docstring entries are NOT auto-added "
            "(no placeholders) and will still be reported."
        ),
    )
    args = parser.parse_args()

    def display(path: Path) -> Path:
        # Repo-relative form when possible; `--paths` may point outside the repo.
        try:
            return path.relative_to(REPO_ROOT)
        except ValueError:
            return path

    targets: list[tuple[Path, str]] = []
    if args.paths:
        for raw in args.paths:
            p = Path(raw).resolve()
            if p.is_dir():
                candidates = sorted(p.rglob("*.py"))
            elif p.is_file():
                candidates = [p]
            else:
                # A mistyped path must not turn into a silent "all documented" pass.
                parser.error(f"{raw}: no such file or directory")
            for candidate in candidates:
                kind = _kind_for_path(candidate)
                if kind is None:
                    label = raw if candidate == p else display(candidate)
                    print(f"Skipping {label}: not under models/ or pipelines/.", file=sys.stderr)
                    continue
                targets.append((candidate, kind))
    else:
        model_files = sorted(MODELS_DIR.rglob("*.py"))
        pipeline_files = sorted(PIPELINES_DIR.rglob("*.py"))
        if args.limit is not None:
            if args.limit < 0:
                parser.error("--limit must be non-negative")
            model_files = model_files[: args.limit]
            pipeline_files = pipeline_files[: args.limit]
            print(
                f"--limit {args.limit}: checking {len(model_files)} model file(s) "
                f"and {len(pipeline_files)} pipeline file(s).",
                file=sys.stderr,
            )
        targets.extend((p, "model") for p in model_files)
        targets.extend((p, "pipeline") for p in pipeline_files)

    if args.fix:
        fix_summaries: list[str] = []
        for path, kind in targets:
            for summary in fix_file(path, kind):
                fix_summaries.append(f"{display(path)}: {summary}")
        if fix_summaries:
            print("Removed stale docstring entries:")
            print("\n".join(f"  {s}" for s in fix_summaries))
        else:
            print("No stale docstring entries to remove.")

    all_errors: list[str] = []
    for path, kind in targets:
        all_errors.extend(check_file(path, kind))

    if all_errors:
        print("\n".join(all_errors))
        print(
            f"\nFound {len(all_errors)} docstring/signature mismatch(es).",
            file=sys.stderr,
        )
        if not args.fix and any("documents argument(s) not in the signature" in e for e in all_errors):
            print(
                "Hint: run `python utils/check_forward_call_docstrings.py --fix` "
                "to remove the stale argument entries flagged above. "
                "(Missing-in-docstring entries must be added manually — the tool "
                "never inserts placeholders.)",
                file=sys.stderr,
            )
        return 1
    print("All forward/__call__ arguments are documented.")
    return 0
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())