File size: 7,273 Bytes
e48cd48
 
858a56f
e48cd48
b394762
 
ecf8f3a
e48cd48
 
 
 
 
ecf8f3a
b394762
 
ecf8f3a
 
 
 
e48cd48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
858a56f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e48cd48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b394762
e48cd48
b394762
 
 
 
858a56f
 
 
 
 
 
 
 
 
b394762
 
 
e48cd48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
858a56f
e48cd48
858a56f
 
 
 
 
 
 
 
e48cd48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
858a56f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
from __future__ import annotations

import ast
import inspect
import sys
import types
from typing import Any, Annotated, get_args, get_origin, get_type_hints, Union


def _typename(tp: Any) -> str:
    """Return a short, readable name for a type or annotation.

    Rendering rules:
      * ``Optional[T]`` / ``T | None`` collapse to the name of ``T``.
      * Any other union is joined with ``" | "`` (``int | str | None``).
      * ``Annotated[T, ...]`` is rendered as ``T``.
      * ``Literal[...]`` keeps its allowed values (``Literal['a', 'b']``).
      * Parameterized generics keep their arguments (``list[int]``,
        ``dict[str, int]``, ``Callable[[int], str]``).
      * ``NoneType`` is rendered as ``None`` and ``Ellipsis`` as ``...``.
      * Plain classes use ``__name__``; anything else (including annotation
        strings) falls back to ``str(tp)`` with the ``typing.`` prefix removed.

    Args:
        tp: A type, typing construct, or annotation string.

    Returns:
        The display name. This function never raises; if anything goes wrong
        while inspecting ``tp`` the plain ``str(tp)`` is returned instead.
    """
    # Local import: Literal is only needed here and the module does not
    # import it at file level.
    from typing import Literal

    try:
        if tp is type(None):
            return "None"
        if tp is Ellipsis:
            return "..."
        if isinstance(tp, list):
            # get_args(Callable[[A, B], R]) yields the parameter list as a
            # real list object, so render it element by element.
            return "[" + ", ".join(_typename(item) for item in tp) + "]"

        origin = get_origin(tp)
        args = get_args(tp)

        if origin is Annotated and args:
            return _typename(args[0])

        if origin is Literal:
            # Literal arguments are values, not types: show them verbatim.
            return "Literal[" + ", ".join(repr(value) for value in args) + "]"

        if origin is Union or (sys.version_info >= (3, 10) and origin is types.UnionType):
            members = [a for a in args if a is not type(None)]
            if len(members) == 1:
                # Optional[T] -> T
                return _typename(members[0])
            return " | ".join(_typename(a) for a in args)

        if origin is not None and args:
            # Parameterized generic: keep the arguments instead of collapsing
            # to the bare origin name.
            head = getattr(origin, "__name__", None) or str(origin).replace("typing.", "")
            return f"{head}[{', '.join(_typename(a) for a in args)}]"

        if hasattr(tp, "__name__"):
            return tp.__name__  # e.g. int, str
        if getattr(tp, "__module__", None) and getattr(tp, "__qualname__", None):
            return f"{tp.__module__}.{tp.__qualname__}"
        return str(tp).replace("typing.", "")
    except Exception:
        # Naming a type must never break docstring generation.
        return str(tp)


def _extract_base_and_meta(annotation: Any) -> tuple[Any, str | None]:
    """Split an annotation into its underlying type and a description.

    Args:
        annotation: Any annotation object. Only ``Annotated[...]`` forms are
            taken apart; everything else is passed through untouched.

    Returns:
        ``(base_type, description)`` where ``description`` is the first
        metadata item that is a string, or ``None`` when the annotation is not
        ``Annotated`` or carries no string metadata. If inspection fails for
        any reason the result is ``(annotation, None)``.
    """
    try:
        if get_origin(annotation) is not Annotated:
            return annotation, None

        parts = get_args(annotation)
        underlying = parts[0] if parts else annotation
        # First string among the metadata items wins; other metadata is ignored.
        description = next((item for item in parts[1:] if isinstance(item, str)), None)
        return underlying, description
    except Exception:
        return annotation, None


def _parse_annotated_string(annot_str: str) -> tuple[str, str | None]:
    """Fallback: split an unevaluated ``Annotated[Type, ..., "desc"]`` string via the AST.

    Used when an annotation could not be evaluated and is still source text.
    The string is parsed, never executed.

    Args:
        annot_str: Annotation source text, e.g. ``'Annotated[int, "count"]'``
            or ``'typing.Annotated[int, Field(gt=0), "count"]'``.

    Returns:
        ``(type_source, description)`` when the string is an ``Annotated[...]``
        subscript whose metadata contains a non-empty string literal; the
        first string literal among the metadata items is used, mirroring how
        evaluated ``Annotated`` objects are handled. Otherwise
        ``(annot_str, None)`` — including when the string is not valid Python.
    """
    try:
        expr = ast.parse(annot_str, mode="eval").body
        if not isinstance(expr, ast.Subscript):
            return annot_str, None

        target = expr.value
        is_annotated = (
            (isinstance(target, ast.Name) and target.id == "Annotated")
            or (isinstance(target, ast.Attribute) and target.attr == "Annotated")
        )
        if not is_annotated or not isinstance(expr.slice, ast.Tuple):
            return annot_str, None

        elts = expr.slice.elts
        if len(elts) < 2:
            return annot_str, None

        # elts[0] is the type; every following element is metadata. String
        # literals are ast.Constant nodes (the legacy ast.Str alias is
        # deprecated and no longer exists on recent Python versions).
        desc = next(
            (
                node.value
                for node in elts[1:]
                if isinstance(node, ast.Constant) and isinstance(node.value, str)
            ),
            None,
        )
        if not desc:
            return annot_str, None

        return ast.unparse(elts[0]), desc
    except Exception:
        # Deliberately best-effort: an unparsable annotation must not break
        # docstring generation, so hand the original text back unchanged.
        return annot_str, None


def _resolve_hints(func: Any, sig: inspect.Signature) -> dict[str, Any]:
    """Resolve the annotations of ``func`` to runtime objects, best-effort.

    Tries ``typing.get_type_hints`` first. If that fails, each string
    annotation of ``sig`` (parameters and return) is evaluated on its own so
    that one unresolvable name does not discard every other hint. Entries
    that cannot be evaluated are simply absent from the result.
    """
    namespace = getattr(func, "__globals__", None)
    try:
        # include_extras=True to retain Annotated metadata
        return get_type_hints(func, include_extras=True, globalns=namespace)
    except Exception:
        pass

    import typing

    # Build the evaluation namespace once: the function's globals plus the
    # common typing names, in case the module did not import them itself.
    scope = dict(namespace or {})
    scope["typing"] = typing
    for alias in ("Annotated", "Literal", "Optional", "Union", "List", "Dict", "Any"):
        scope.setdefault(alias, getattr(typing, alias))

    pending = {name: param.annotation for name, param in sig.parameters.items()}
    pending["return"] = sig.return_annotation

    resolved: dict[str, Any] = {}
    for name, raw in pending.items():
        if not isinstance(raw, str):
            continue
        try:
            # NOTE(security): eval() runs annotation source text taken from
            # the decorated function's own definition. Do not route
            # externally supplied strings through this path.
            resolved[name] = eval(raw, scope)
        except Exception:
            # Leave it unresolved; the caller falls back to the raw string.
            continue
    return resolved


def _describe_annotation(annot: Any) -> tuple[str | None, str]:
    """Return ``(type name or None, description)`` for a single annotation.

    The type name is ``None`` when there is no annotation at all. The
    description comes from ``Annotated`` string metadata, either on the
    evaluated object or, for annotations that are still strings, via the AST
    fallback; it is ``""`` when none is found.
    """
    if annot is inspect.Parameter.empty:
        return None, ""

    base, meta = _extract_base_and_meta(annot)
    # If meta is missing and the annotation is still a string, try the AST fallback.
    if meta is None and isinstance(annot, str):
        base_str, meta_str = _parse_annotated_string(annot)
        if meta_str:
            base, meta = base_str, meta_str
    return _typename(base), meta or ""


def autodoc(summary: str | None = None, returns: str | None = None, *, force: bool = False):
    """
    Decorator that auto-generates a concise Google-style docstring from a function's
    type hints and Annotated metadata. Useful for Gradio MCP where docstrings are
    used for tool descriptions and parameter docs.

    The generated docstring consists of a summary line, an ``Args:`` section
    (omitted when there is nothing to list; a parameter named ``self`` is
    skipped) and a ``Returns:`` section (present when ``returns`` is given or
    the function has a return annotation).

    Args:
        summary: Optional one-line summary for the function. If not provided,
            will generate a simple sentence from the function name.
        returns: Optional return value description. If not provided, only the
            return type (and its Annotated description, if any) will be listed.
        force: When True, overwrite an existing docstring. Default False.

    Returns:
        A decorator that returns the original function with its __doc__
        populated (unless skipped).
    """

    def decorator(func):
        # Skip if docstring already present and not forcing
        if not force and func.__doc__ and func.__doc__.strip():
            return func

        sig = inspect.signature(func)
        hints = _resolve_hints(func, sig)

        lines: list[str] = []

        # Summary line
        if summary and summary.strip():
            lines.append(summary.strip())
        else:
            pretty = func.__name__.replace("_", " ").strip().capitalize()
            if not pretty.endswith("."):
                pretty += "."
            lines.append(pretty)

        # Args section: collect entries first so that no empty "Args:" header
        # is emitted when every parameter is skipped.
        arg_lines: list[str] = []
        for name, param in sig.parameters.items():
            if name == "self":
                continue
            tname, desc = _describe_annotation(hints.get(name, param.annotation))
            entry = f"    {name} ({tname}): {desc}" if tname else f"    {name}: {desc}"
            arg_lines.append(entry.rstrip())
        if arg_lines:
            lines.append("")
            lines.append("Args:")
            lines.extend(arg_lines)

        # Returns section
        ret_hint = hints.get("return", sig.return_annotation)
        if returns or (ret_hint and ret_hint is not inspect.Signature.empty):
            lines.append("")
            lines.append("Returns:")
            if returns:
                lines.append(f"    {returns}")
            else:
                rtype, rdesc = _describe_annotation(ret_hint)
                if rdesc:
                    lines.append(f"    {rtype}: {rdesc}")
                else:
                    lines.append(f"    {rtype}")

        func.__doc__ = "\n".join(lines).strip() + "\n"
        return func

    return decorator


# Public API: only the decorator is exported by `from <module> import *`;
# the underscore-prefixed helpers above are internal.
__all__ = ["autodoc"]