File size: 3,937 Bytes
c4f3819
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
"""AST-based discovery for small Python packages."""

from __future__ import annotations

import ast
from dataclasses import dataclass
from pathlib import Path


@dataclass(frozen=True)
class FunctionInfo:
    """Metadata needed to generate characterization tests."""

    module: str
    qualname: str
    name: str
    parameters: tuple[str, ...]
    type_hints: dict[str, str]
    return_hint: str | None
    docstring: str
    source: str
    file_path: str
    lineno: int

    @property
    def import_name(self) -> str:
        return self.qualname.split(".")[-1]

    @property
    def arity(self) -> int:
        return len(self.parameters)


@dataclass(frozen=True)
class Analysis:
    """Public API discovered in a target package."""

    root: Path
    package: str
    functions: tuple[FunctionInfo, ...]


def analyze(root: str | Path, package: str | None = None) -> Analysis:
    """Discover public top-level functions and public class methods."""
    root_path = Path(root).resolve()
    if not root_path.exists():
        raise FileNotFoundError(root_path)

    package_name = package or root_path.name
    functions: list[FunctionInfo] = []
    for file_path in sorted(root_path.rglob("*.py")):
        if file_path.name == "__init__.py":
            continue
        module = _module_name(root_path, package_name, file_path)
        source = file_path.read_text(encoding="utf-8")
        tree = ast.parse(source, filename=str(file_path))
        lines = source.splitlines()
        functions.extend(_top_level_functions(tree, module, file_path, lines))
        functions.extend(_class_methods(tree, module, file_path, lines))

    return Analysis(root=root_path, package=package_name, functions=tuple(functions))


def _module_name(root: Path, package: str, file_path: Path) -> str:
    relative = file_path.relative_to(root).with_suffix("")
    parts = (package, *relative.parts)
    return ".".join(parts)


def _top_level_functions(
    tree: ast.Module, module: str, file_path: Path, lines: list[str]
) -> list[FunctionInfo]:
    found: list[FunctionInfo] = []
    for node in tree.body:
        if isinstance(node, ast.FunctionDef) and _is_public(node.name):
            found.append(_function_info(node, module, node.name, file_path, lines))
    return found


def _class_methods(
    tree: ast.Module, module: str, file_path: Path, lines: list[str]
) -> list[FunctionInfo]:
    found: list[FunctionInfo] = []
    for class_node in tree.body:
        if not isinstance(class_node, ast.ClassDef) or not _is_public(class_node.name):
            continue
        for node in class_node.body:
            if isinstance(node, ast.FunctionDef) and _is_public(node.name):
                qualname = f"{class_node.name}.{node.name}"
                found.append(_function_info(node, module, qualname, file_path, lines))
    return found


def _function_info(
    node: ast.FunctionDef, module: str, qualname: str, file_path: Path, lines: list[str]
) -> FunctionInfo:
    parameters: list[str] = []
    hints: dict[str, str] = {}
    for arg in node.args.args:
        if arg.arg in {"self", "cls"}:
            continue
        parameters.append(arg.arg)
        if arg.annotation is not None:
            hints[arg.arg] = ast.unparse(arg.annotation)

    source = _source_segment(node, lines)
    return FunctionInfo(
        module=module,
        qualname=qualname,
        name=qualname,
        parameters=tuple(parameters),
        type_hints=hints,
        return_hint=ast.unparse(node.returns) if node.returns is not None else None,
        docstring=ast.get_docstring(node) or "",
        source=source,
        file_path=str(file_path),
        lineno=node.lineno,
    )


def _source_segment(node: ast.AST, lines: list[str]) -> str:
    end = getattr(node, "end_lineno", node.lineno)
    return "\n".join(lines[node.lineno - 1 : end])


def _is_public(name: str) -> bool:
    return not name.startswith("_")