| """Utilities related to attribute docstring extraction.""" |
|
|
| from __future__ import annotations |
|
|
| import ast |
| import inspect |
| import sys |
| import textwrap |
| from typing import Any |
|
|
|
|
class DocstringVisitor(ast.NodeVisitor):
    """Collect attribute docstrings from a parsed source tree.

    An attribute docstring is a string-literal expression statement placed
    right after an annotated assignment to a plain name, e.g.::

        x: int = 1
        '''Docstring for x.'''

    Found pairs are stored in `attrs`, mapping the attribute name to its
    docstring cleaned with `inspect.cleandoc`. The whole tree is traversed,
    so pairs inside nested classes or function bodies are collected as well.
    """

    def __init__(self) -> None:
        super().__init__()

        # Name of the last annotated-assignment target not yet paired with a docstring.
        self.target: str | None = None
        # Collected mapping: attribute name -> cleaned docstring.
        self.attrs: dict[str, str] = {}
        # Type of the node whose visit (including its whole subtree) finished last.
        self.previous_node_type: type[ast.AST] | None = None

    def visit(self, node: ast.AST) -> Any:
        node_result = super().visit(node)
        # Updated only after the node's subtree has been visited, so that
        # `visit_Expr` sees the node visited just before it (normally the
        # preceding statement).
        self.previous_node_type = type(node)
        return node_result

    def visit_AnnAssign(self, node: ast.AnnAssign) -> Any:
        # Only plain names (`x: int`) can be documented. For other targets
        # (`obj.x: int`, `a[0]: int`) clear the pending target so a following
        # string is not attributed to an earlier, unrelated annotation.
        self.target = node.target.id if isinstance(node.target, ast.Name) else None

    def visit_Expr(self, node: ast.Expr) -> Any:
        value = node.value
        if (
            isinstance(value, ast.Constant)
            and isinstance(value.value, str)
            and self.previous_node_type is ast.AnnAssign
        ):
            docstring = inspect.cleandoc(value.value)
            if self.target:
                self.attrs[self.target] = docstring
            # Each annotation can be documented at most once.
            self.target = None
|
|
|
|
def _dedent_source_lines(source: list[str]) -> str:
    """Join `source` and strip the common leading indentation so it can be parsed.

    Class definitions nested inside another block (e.g. a function) are
    indented, which `ast.parse` rejects. If the text still starts with
    whitespace after dedenting (some line, typically inside a multiline string,
    sits at a lower indentation than the class itself), it is wrapped in a
    dummy `dedent_workaround` function so that it parses anyway.

    Args:
        source: Source lines, joined without any separator (presumably each
            line keeps its own newline, as `inspect` returns them).

    Returns:
        The source text, ready for `ast.parse`.
    """
    text = textwrap.dedent(''.join(source))
    if text[:1] not in (' ', '\t'):
        return text
    # Still indented: make it the body of a function so the parser accepts it.
    return f'def dedent_workaround():\n{text}'
|
|
|
|
def _extract_source_from_frame(cls: type[Any]) -> list[str] | None:
    """Find the source lines of `cls` by walking up the call stack.

    Each frame belonging to the module `cls` was defined in is inspected: the
    source block starting at that frame's current line is extracted with
    `inspect.getblock` and parsed. The first block that is a class definition
    named `cls.__name__` is returned.

    This is best-effort: frames whose source cannot be retrieved or parsed
    are skipped.

    Args:
        cls: The class whose definition should be located.

    Returns:
        The source lines of the class definition, or `None` if no matching
        frame was found.
    """
    # Loop-invariant: the class's module is the same for every frame checked.
    cls_module = inspect.getmodule(cls)
    frame = inspect.currentframe()
    current = None
    try:
        while frame is not None:
            current, frame = frame, frame.f_back
            if inspect.getmodule(current) is not cls_module:
                continue

            lnum = current.f_lineno
            if not lnum:
                # `f_lineno` can be `None` when the frame has no associated line
                # number; slicing `lines` with it would raise `TypeError`.
                continue

            try:
                lines, _ = inspect.findsource(current)
            except OSError:
                # Source can't be retrieved (e.g. when running in an interactive
                # terminal); we don't want to error here.
                continue

            block_lines = inspect.getblock(lines[lnum - 1 :])
            try:
                block_tree = ast.parse(_dedent_source_lines(block_lines))
            except SyntaxError:
                continue
            if not block_tree.body:
                # Only blanks/comments were extracted (e.g. the line number points
                # past stale source); `body[0]` would raise `IndexError`.
                continue

            stmt = block_tree.body[0]
            if isinstance(stmt, ast.FunctionDef) and stmt.name == 'dedent_workaround':
                # `_dedent_source_lines` wrapped the block in a dummy function.
                stmt = stmt.body[0]
            if isinstance(stmt, ast.ClassDef) and stmt.name == cls.__name__:
                return block_lines
        return None
    finally:
        # Drop frame references explicitly, as the `inspect` documentation
        # recommends, instead of relying on the cycle collector.
        del frame, current
|
|
|
|
def extract_docstrings_from_cls(cls: type[Any], use_inspect: bool = False) -> dict[str, str]:
    """Map model attributes and their corresponding docstring.

    Args:
        cls: The class of the Pydantic model to inspect.
        use_inspect: Whether to skip usage of frames to find the object and use
            the `inspect` module instead.

    Returns:
        A mapping containing attribute names and their corresponding docstring.
        An empty mapping is returned when the class source cannot be retrieved
        or parsed.
    """
    source: list[str] | None
    if use_inspect or sys.version_info >= (3, 13):
        # Python 3.13+ lets `inspect` locate a class via its `__firstlineno__`
        # attribute; older versions locate it by name, which can pick the wrong
        # class when several share a name, hence the frame-based lookup below.
        try:
            source, _ = inspect.getsourcelines(cls)
        except (OSError, TypeError):
            # OSError: source not available (e.g. interactive session).
            # TypeError: `inspect` treats the class as built-in, e.g. when its
            # `__module__` is not a loaded module with a `__file__`.
            return {}
    else:
        source = _extract_source_from_frame(cls)

    if not source:
        return {}

    try:
        tree = ast.parse(_dedent_source_lines(source))
    except SyntaxError:
        # The retrieved source may not parse (for instance if the file changed
        # on disk after import); extraction is best-effort, like the frame path.
        return {}

    visitor = DocstringVisitor()
    visitor.visit(tree)
    return visitor.attrs
|
|