File size: 5,339 Bytes
d14041c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
"""Generate Typer CLI options from Pydantic models."""

import functools
import inspect
from pathlib import Path
from typing import Any, Callable, Optional, get_args, get_origin

import typer
from pydantic import BaseModel


def _python_name_to_cli_flag(name: str) -> str:
    """Convert python_name to --cli-flag."""
    return "--" + name.replace("_", "-")


def _unwrap_optional(annotation: Any) -> Any:
    """Unwrap Optional[X] to X."""
    origin = get_origin(annotation)
    if origin is not None:
        args = get_args(annotation)
        if type(None) in args:
            non_none = [a for a in args if a is not type(None)]
            if non_none:
                return non_none[0]
    return annotation


def _is_bool_field(annotation: Any) -> bool:
    """Check if field is a boolean (including Optional[bool])."""
    return _unwrap_optional(annotation) is bool


def _is_list_type(annotation: Any) -> bool:
    """Check if type is a List."""
    return get_origin(annotation) is list


def _get_python_type(annotation: Any) -> type:
    """Get the Python type for annotation."""
    unwrapped = _unwrap_optional(annotation)
    if unwrapped in (str, int, float, bool, Path):
        return unwrapped
    return str


def _collect_config_fields(config_class: type[BaseModel]) -> list[tuple[str, Any]]:
    """
    Collect all fields from a config class, flattening nested models. Returns list of
    (name, field_info) tuples. Raises ValueError on duplicate field names.
    """
    fields = []
    seen_names: set[str] = set()

    for name, field_info in config_class.model_fields.items():
        annotation = field_info.annotation
        # Skip nested models - recurse into them
        if isinstance(annotation, type) and issubclass(annotation, BaseModel):
            for nested_name, nested_field in annotation.model_fields.items():
                if nested_name in seen_names:
                    raise ValueError(f"Duplicate field name '{nested_name}' in config")
                seen_names.add(nested_name)
                fields.append((nested_name, nested_field))
        else:
            if name in seen_names:
                raise ValueError(f"Duplicate field name '{name}' in config")
            seen_names.add(name)
            fields.append((name, field_info))
    return fields


def add_options_from_config(config_class: type[BaseModel]) -> Callable:
    """
    Decorator that adds CLI options for all fields in a Pydantic config model.

    The decorated function should declare a `config_overrides: dict = None` parameter
    which will receive a dict of all CLI-provided config values.
    """
    fields = _collect_config_fields(config_class)
    field_names = {name for name, field_info in fields if not _is_list_type(field_info.annotation)}

    def decorator(func: Callable) -> Callable:
        sig = inspect.signature(func)
        original_params = list(sig.parameters.values())
        original_param_names = {p.name for p in original_params}

        # Build new parameters: config fields first, then original params
        new_params = []

        for field_name, field_info in fields:
            # Skip fields already defined in function signature (e.g., with envvar)
            if field_name in original_param_names:
                continue
            annotation = field_info.annotation
            if _is_list_type(annotation):
                continue

            flag_name = _python_name_to_cli_flag(field_name)
            help_text = field_info.description or ""

            if _is_bool_field(annotation):
                default = typer.Option(
                    None,
                    f"{flag_name}/--no-{field_name.replace('_', '-')}",
                    help=help_text,
                )
                param = inspect.Parameter(
                    field_name,
                    inspect.Parameter.POSITIONAL_OR_KEYWORD,
                    default=default,
                    annotation=Optional[bool],
                )
            else:
                py_type = _get_python_type(annotation)
                default = typer.Option(None, flag_name, help=help_text)
                param = inspect.Parameter(
                    field_name,
                    inspect.Parameter.POSITIONAL_OR_KEYWORD,
                    default=default,
                    annotation=Optional[py_type],
                )
            new_params.append(param)

        # Add original params, excluding config_overrides (will be injected)
        for param in original_params:
            if param.name != "config_overrides":
                new_params.append(param)

        new_sig = sig.replace(parameters=new_params)

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            config_overrides = {}
            for key in list(kwargs.keys()):
                if key in field_names:
                    if kwargs[key] is not None:
                        config_overrides[key] = kwargs[key]
                    # Only delete if not an explicitly declared parameter
                    if key not in original_param_names:
                        del kwargs[key]

            kwargs["config_overrides"] = config_overrides
            return func(*args, **kwargs)

        wrapper.__signature__ = new_sig
        return wrapper

    return decorator