File size: 6,018 Bytes
a2afe2f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
"""Router testers."""

import ast
import os
from typing import Any, Dict, List, Optional

from openbb_core.app.provider_interface import ProviderInterface

from extensions.tests.utils.helpers import (
    collect_router_functions,
    collect_routers,
    find_decorator,
    find_missing_router_function_models,
    get_all_fields,
    get_decorator_details,
    get_required_fields,
    import_routers,
    parse_example_string,
)


def check_router_function_models() -> List[str]:
    """Report router functions whose model is absent from the provider interface map.

    Discovers routers under ``extensions``, imports them, collects their
    functions and compares them against ``ProviderInterface().map`` via
    ``find_missing_router_function_models``.

    Returns
    -------
    List[str]
        Whatever ``find_missing_router_function_models`` reports as missing.
    """
    interface_map = ProviderInterface().map
    router_functions = collect_router_functions(
        import_routers(collect_routers("extensions"))
    )
    return find_missing_router_function_models(router_functions, interface_map)


def check_router_model_functions_signature() -> List[str]:
    """Check if the router model functions have the correct signature.

    For every collected router function whose decorator is found by
    ``find_decorator`` and does not mention ``GET``/``POST``:

    * if the decorator mentions ``model``, the function must declare all of
      ``cc``, ``provider_choices``, ``standard_params`` and ``extra_params``
      as parameters;
    * its return annotation must mention ``OBBject`` (a missing return
      annotation is reported as a violation rather than raising ``KeyError``).

    Returns
    -------
    List[str]
        Missing-argument messages followed by missing-return-type messages;
        empty when every function conforms.
    """
    expected_args = ["cc", "provider_choices", "standard_params", "extra_params"]
    expected_return_type = "OBBject"
    missing_args: List[str] = []
    missing_return_type: List[str] = []

    routers = collect_routers("extensions")
    loaded_routers = import_routers(routers)
    router_functions = collect_router_functions(loaded_routers)

    for router_name, functions in router_functions.items():
        for function in functions:
            decorator = find_decorator(
                os.path.join(*router_name.split(".")) + ".py",
                function.__name__,
            )
            if not decorator:
                continue
            # Plain GET/POST endpoints are exempt from the model-command signature.
            if "POST" in decorator or "GET" in decorator:
                continue

            # co_varnames lists parameters first, then local variables; slice
            # to the declared parameters so a local that happens to share a
            # name (e.g. ``cc``) cannot satisfy the check.
            code = function.__code__
            param_count = code.co_argcount + code.co_kwonlyargcount
            args = list(code.co_varnames[:param_count])

            if "model" in decorator and not all(arg in args for arg in expected_args):
                missing_args.append(
                    f"{function.__name__} in {router_name} missing expected args: {expected_args}"
                )

            # .get() so an unannotated function is reported (str(None) == "None")
            # instead of aborting the whole check with a KeyError.
            return_annotation = function.__annotations__.get("return")
            if expected_return_type not in str(return_annotation):
                missing_return_type.append(
                    f"{function.__name__} in {router_name} "
                    f"doesn't have the expected return type: {expected_return_type}"
                )

    return missing_args + missing_return_type


def check_general(
    keywords: Dict, examples: List, router_name: str, function: Any
) -> List[str]:
    """Check for general violations in the router command examples.

    The only general rule enforced here is that the command declares an
    ``examples`` keyword with a non-empty value.

    Parameters
    ----------
    keywords : Dict
        Decorator keyword arguments of the command.
    examples : List
        The ``examples`` value taken from ``keywords`` (may be empty).
    router_name : str
        Router identifier, used in the violation message.
    function : Any
        The router function; only its ``__name__`` is used.

    Returns
    -------
    List[str]
        A single "missing examples" message, or an empty list.
    """
    has_examples = "examples" in keywords and bool(examples)
    if has_examples:
        return []
    return [f"'{router_name}' > '{function.__name__}': missing examples"]


def check_api(
    examples: str, router_name: str, model: Optional[str], function: Any
) -> List[str]:
    """Check for API examples.

    Applies only when ``model`` is given and the parsed examples contain an
    ``"APIEx"`` entry. It then verifies that:

    * every parameter used in an API example is a field of the model, or
      ``provider``;
    * at least one API example supplies all of the model's required fields.

    Parameters
    ----------
    examples : str
        Raw examples passed to ``parse_example_string``.
    router_name : str
        Router identifier, used in violation messages.
    model : Optional[str]
        Model name; surrounding single quotes are stripped before lookup.
    function : Any
        The router function; only its ``__name__`` is used.

    Returns
    -------
    List[str]
        Violation messages; empty when all checks pass or none apply.
    """
    api_example_violation: List[str] = []
    parsed_examples = parse_example_string(examples)
    if not model or "APIEx" not in parsed_examples:
        return api_example_violation

    model_name = model.strip("'")
    required_fields = set(get_required_fields(model_name))
    # Build a fresh set instead of appending to the list returned by
    # get_all_fields: the helper's return value is never mutated (safe if it
    # is shared/cached) and membership tests become O(1).
    supported_fields = set(get_all_fields(model_name)) | {"provider"}
    required_fields_met = False

    for api_example in parsed_examples["APIEx"]:
        params = ast.literal_eval(api_example.get("params", "{}"))
        if not required_fields_met and required_fields.issubset(params.keys()):
            required_fields_met = True

        # Check for unsupported parameters
        for param in params:
            if param not in supported_fields:
                api_example_violation.append(
                    f"'{router_name}' > '{function.__name__}': param '{param}' is not supported by the command."
                )

    # If after checking all examples, required fields are still not met
    if not required_fields_met:
        api_example_violation.append(
            f"'{router_name}' > '{function.__name__}': missing example with required fields only > {required_fields}"
        )

    return api_example_violation


def check_router_command_examples() -> List[str]:
    """Check if the router command examples satisfy criteria.

    Functions are skipped when:

    * their name contains ``basemodel_to_df``;
    * their module name lacks ``router``;
    * no decorator is found for them;
    * their decorator is not ``router.command``.

    Every remaining command goes through ``check_general``. Commands that
    have examples also go through ``check_api``.

    Returns
    -------
    List[str]
        General violations, then API-example violations, then Python-example
        violations (the last list is never filled by this function).
    """
    general_issues: List[str] = []
    api_issues: List[str] = []
    # Placeholder: nothing in this function appends Python-example violations.
    python_issues: List[str] = []

    router_functions = collect_router_functions(
        import_routers(collect_routers("extensions"))
    )

    for router_name, functions in router_functions.items():
        for function in functions:
            name = function.__name__
            if "basemodel_to_df" in name or "router" not in function.__module__:
                continue

            source_path = os.path.join(*router_name.split(".")) + ".py"
            if not find_decorator(source_path, name):
                continue

            details = get_decorator_details(function)
            if not details or details.name != "router.command":
                continue

            keywords = details.kwargs or {}
            examples = keywords.get("examples", [])
            general_issues.extend(
                check_general(keywords, examples, router_name, function)
            )
            if not examples:
                continue

            api_issues.extend(
                check_api(examples, router_name, keywords.get("model", None), function)
            )

    return general_issues + api_issues + python_issues