File size: 3,160 Bytes
b93364a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
"""Utilities for argparse_translator module."""

from argparse import Action, ArgumentParser
from typing import List, Optional, Tuple


def in_group(parser: ArgumentParser, argument_name: str, group_title: str) -> bool:
    """Check if an argument is in a group of an ArgumentParser."""
    for action_group in parser._action_groups:  # pylint: disable=protected-access
        if action_group.title == group_title:
            for (
                action
            ) in action_group._group_actions:  # pylint: disable=protected-access
                opts = action.option_strings
                if (opts and opts[0] == argument_name) or action.dest == argument_name:
                    return True
    return False


def remove_argument(parser: ArgumentParser, argument_name: str) -> List[Optional[str]]:
    """Remove an argument from an ArgumentParser."""
    groups_w_arg = []

    # remove the argument from the parser
    for action in parser._actions:  # pylint: disable=protected-access
        opts = action.option_strings
        if (opts and opts[0] == argument_name) or action.dest == argument_name:
            parser._remove_action(action)  # pylint: disable=protected-access
            break

    # remove from all groups
    for action_group in parser._action_groups:  # pylint: disable=protected-access
        for action in action_group._group_actions:  # pylint: disable=protected-access
            opts = action.option_strings
            if (opts and opts[0] == argument_name) or action.dest == argument_name:
                action_group._group_actions.remove(  # pylint: disable=protected-access
                    action
                )
                groups_w_arg.append(action_group.title)

    # remove from _action_groups dict
    parser._option_string_actions.pop(  # pylint: disable=protected-access
        f"--{argument_name}", None
    )

    return groups_w_arg


def get_argument_choices(parser: ArgumentParser, argument_name: str) -> Tuple:
    """Get the choices of an argument from an ArgumentParser."""
    for action in parser._actions:  # pylint: disable=protected-access
        opts = action.option_strings
        if (opts and opts[0] == argument_name) or action.dest == argument_name:
            return tuple(action.choices or ())
    return ()


def get_argument_optional_choices(parser: ArgumentParser, argument_name: str) -> bool:
    """Get the optional_choices attribute of an argument from an ArgumentParser."""
    for action in parser._actions:  # pylint: disable=protected-access
        opts = action.option_strings
        if (
            (opts and opts[0] == argument_name)
            or action.dest == argument_name
            and hasattr(action, "optional_choices")
        ):
            return (
                action.optional_choices  # type: ignore[attr-defined] # this is a custom attribute
            )
    return False


def set_optional_choices(action: Action, optional_choices: bool):
    """Set the optional_choices attribute of an action."""
    if not hasattr(action, "optional_choices") and optional_choices:
        setattr(action, "optional_choices", optional_choices)