# OpenBB / cli / openbb_cli / controllers / base_platform_controller.py
# CatPtain's picture
# Upload 70 files
# b93364a verified
"""Platform Controller base class."""
import os
from functools import partial, update_wrapper
from types import MethodType
from typing import Dict, List, Optional
import pandas as pd
from openbb import obb
from openbb_charting.core.openbb_figure import OpenBBFigure
from openbb_cli.argparse_translator.argparse_class_processor import (
ArgparseClassProcessor,
)
from openbb_cli.config.menu_text import MenuText
from openbb_cli.controllers.base_controller import BaseController
from openbb_cli.controllers.utils import export_data, print_rich_table
from openbb_cli.session import Session
from openbb_core.app.model.obbject import OBBject
# Module-level Session used throughout this file for console output,
# settings and the OBBject registry.
# NOTE(review): created at import time; presumably Session hands back a
# shared instance -- confirm in openbb_cli.session.
session = Session()
class DummyTranslation:
    """Empty stand-in for a translated platform target.

    Used when a controller is built from explicit translators only: it
    exposes the ``paths`` and ``translators`` attributes that
    ``PlatformController`` reads, both as empty dictionaries.
    """

    def __init__(self):
        """Create a holder with no paths and no translators."""
        self.paths, self.translators = {}, {}
class PlatformController(BaseController):
    """Platform Controller Base class.

    Builds a CLI menu at runtime, either from a platform class (translated
    into argparse commands by ``ArgparseClassProcessor``) or from a ready-made
    mapping of translators handed down by a parent controller. Each translator
    becomes a ``call_<command>`` method on the instance, and each sub-path
    becomes a ``call_<menu>`` method that opens a dynamically created
    ``PlatformController`` subclass.

    NOTE(review): ``self.path``, ``self.queue``, ``self.load_class``,
    ``self.parse_known_args_and_warn``, ``self.update_completer``,
    ``self.choices_default``, ``CHOICES_MENUS`` and ``CHOICES_COMMANDS`` are
    not defined here; they are expected to come from ``BaseController``.
    """

    CHOICES_GENERATION = True

    def __init__(  # pylint: disable=too-many-positional-arguments
        self,
        name: str,
        parent_path: List[str],
        platform_target: Optional[type] = None,
        queue: Optional[List[str]] = None,
        translators: Optional[Dict] = None,
    ):
        """Construct a Platform based Controller.

        Parameters
        ----------
        name : str
            Menu name. It is appended to ``parent_path`` to build ``PATH``,
            and the ``"<name>_"`` prefix is stripped from translator names
            to obtain command names.
        parent_path : List[str]
            Names of the enclosing menus, joined with "/" to build ``PATH``.
        platform_target : Optional[type]
            Class whose methods are translated into commands through
            ``ArgparseClassProcessor``.
        queue : Optional[List[str]]
            Pending commands, forwarded to ``BaseController.__init__``.
        translators : Optional[Dict]
            Pre-built translators keyed by name. When not ``None`` they are
            used instead of the ones generated from ``platform_target``.

        Raises
        ------
        ValueError
            If neither ``platform_target`` nor ``translators`` is provided
            (an empty ``translators`` mapping counts as not provided).
        """
        # PATH is assigned before BaseController.__init__ runs.
        # NOTE(review): presumably the base initializer reads it -- confirm.
        self.PATH = f"/{'/'.join(parent_path)}/{name}/" if parent_path else f"/{name}/"
        super().__init__(queue)
        self._name = name

        if not (platform_target or translators):
            raise ValueError("Either platform_target or translators must be provided.")

        # Without a platform target there is nothing to translate, so an empty
        # stand-in supplies the ``paths``/``translators`` attributes read below.
        self._translated_target = (
            ArgparseClassProcessor(
                target_class=platform_target, reference=obb.reference["paths"]  # type: ignore
            )
            if platform_target
            else DummyTranslation()
        )
        self.translators = (
            translators
            if translators is not None
            else getattr(self._translated_target, "translators", {})
        )
        self.paths = getattr(self._translated_target, "paths", {})

        if self.translators:
            self._link_obbject_to_data_processing_commands()
            self._generate_commands()
            self._generate_sub_controllers()
            self.update_completer(self.choices_default)

    def _link_obbject_to_data_processing_commands(self):
        """Link data processing commands to the OBBject registry.

        Every translator argument whose ``dest`` is ``"data"`` is rewired to
        take a single string chosen from the cached results: ``OBB<index>``
        for each registered OBBject, plus every ``register_key`` stored in an
        OBBject's ``extra`` mapping.
        """
        obbjects = session.obbject_registry.obbjects
        # Indexed choices first, then key-based choices.
        data_choices = [f"OBB{i}" for i in range(len(obbjects))] + [
            obbject.extra["register_key"]
            for obbject in obbjects
            if "register_key" in obbject.extra
        ]
        for trl in self.translators.values():
            for action in trl._parser._actions:  # pylint: disable=protected-access
                if action.dest == "data":
                    # Each action gets its own copy so later edits do not leak.
                    action.choices = list(data_choices)
                    action.type = str
                    action.nargs = None

    def _intersect_data_processing_commands(self, ns_parser):
        """Replace an OBBject reference in ``ns_parser.data`` with its results.

        ``data`` may be ``OBB<index>`` or a ``register_key``. If it resolves to
        a cached ``OBBject``, ``ns_parser.data`` is replaced by that object's
        ``results``. Otherwise it is left as parsed. An ``OBB<index>`` string
        is always converted to the integer index.

        Parameters
        ----------
        ns_parser : argparse.Namespace
            Parsed arguments; only the optional ``data`` attribute is touched.

        Returns
        -------
        argparse.Namespace
            The same namespace, possibly with ``data`` replaced.
        """
        if hasattr(ns_parser, "data"):
            data = ns_parser.data
            # Only "OBB" followed by decimal digits is an index reference. Any
            # other value (None, or a key that merely contains "OBB") is kept
            # as-is instead of crashing in int().
            if isinstance(data, str) and data.startswith("OBB") and data[3:].isdecimal():
                data = int(data[3:])
                ns_parser.data = data

            registry = session.obbject_registry
            if (data in range(len(registry.obbjects))) or (
                data in registry.obbject_keys
            ):
                obbject = registry.get(data)
                if obbject and isinstance(obbject, OBBject):
                    setattr(ns_parser, "data", obbject.results)

        return ns_parser

    def _generate_sub_controllers(self):
        """Create a sub-menu for each entry in ``self.paths``.

        Entries whose value is ``"path"`` are skipped. For every other path,
        the translators whose name contains ``"<name>_<path>"`` are handed to
        a new ``PlatformController`` subclass (keyed by their name with the
        ``"<name>_<path>_"`` prefix removed). They are also dropped from this
        menu's ``CHOICES_COMMANDS``, and a ``call_<path>`` method is attached
        to open the sub-menu.
        """
        for path, value in self.paths.items():
            if value == "path":
                continue

            sub_menu_translators = {}
            choices_commands = []

            for translator_name, translator in self.translators.items():
                # NOTE(review): substring match, not a prefix match -- confirm
                # that is intended.
                if f"{self._name}_{path}" in translator_name:
                    new_name = translator_name.replace(f"{self._name}_{path}_", "")
                    sub_menu_translators[new_name] = translator
                    choices_commands.append(new_name)

                    # The command now lives in the sub-menu, not in this one.
                    # NOTE(review): if CHOICES_COMMANDS is a class-level list,
                    # this removal is shared by every instance -- confirm.
                    if translator_name in self.CHOICES_COMMANDS:
                        self.CHOICES_COMMANDS.remove(translator_name)

            # Create the sub controller as a new class
            class_name = f"{self._name.capitalize()}{path.capitalize()}Controller"
            SubController = type(
                class_name,
                (PlatformController,),
                {
                    "CHOICES_GENERATION": True,
                    "CHOICES_COMMANDS": choices_commands,
                },
            )

            self._generate_controller_call(
                controller=SubController,
                name=path,
                parent_path=self.path,
                translators=sub_menu_translators,
            )

    def _generate_commands(self):
        """Attach a ``call_<command>`` method for every translator.

        The command name is the translator name with every ``"<name>_"``
        occurrence removed.
        """
        for name, translator in self.translators.items():
            new_name = name.replace(f"{self._name}_", "")
            self._generate_command_call(name=new_name, translator=translator)

    def _generate_command_call(self, name, translator):
        """Attach a ``call_<name>`` method that runs ``translator``.

        The generated method does the following:
          1. Parses the CLI arguments with the translator's parser.
          2. Resolves any OBBject reference in ``data``.
          3. Executes the translator and optionally registers the resulting
             OBBject.
          4. Shows the results as a chart, an interactive table or a rich
             table.
          5. Exports the results when ``--export`` was given.

        Any exception raised while doing so is printed in red instead of
        propagating.
        """

        def method(self, other_args: List[str], translator=translator):
            """Call the translator."""
            parser = translator.parser

            ns_parser = self.parse_known_args_and_warn(
                parser=parser,
                other_args=other_args,
                export_allowed="raw_data_and_figures",
            )
            if not ns_parser:
                # Nothing to run when parsing yields no namespace.
                return

            try:
                ns_parser = self._intersect_data_processing_commands(ns_parser)
                export = hasattr(ns_parser, "export") and ns_parser.export
                store_obbject = (
                    hasattr(ns_parser, "register_obbject")
                    and ns_parser.register_obbject
                )

                obbject = translator.execute_func(parsed_args=ns_parser)
                df: pd.DataFrame = pd.DataFrame()
                fig: Optional[OpenBBFigure] = None
                title = f"{self.PATH}{translator.func.__name__}"

                if obbject:
                    # Plain lists are wrapped so they get the OBBject treatment.
                    if isinstance(obbject, list):
                        obbject = OBBject(results=obbject)

                    if isinstance(obbject, OBBject):
                        if (
                            session.max_obbjects_exceeded()
                            and obbject.results
                            and store_obbject
                        ):
                            # Make room before registering; per the message
                            # below, remove() drops the oldest entry.
                            session.obbject_registry.remove()
                            session.console.print(
                                "[yellow]Maximum number of OBBjects reached. "
                                "The oldest entry was removed.[/yellow]"
                            )

                        # use the obbject to store the command so we can display it later on results
                        obbject.extra["command"] = f"{title} {' '.join(other_args)}"
                        # if there is a registry key in the parser, store to the obbject
                        if hasattr(ns_parser, "register_key") and ns_parser.register_key:
                            if (
                                ns_parser.register_key
                                not in session.obbject_registry.obbject_keys
                            ):
                                obbject.extra["register_key"] = str(
                                    ns_parser.register_key
                                )
                            else:
                                session.console.print(
                                    f"[yellow]Key `{ns_parser.register_key}` already exists in the registry. "
                                    "The `OBBject` was kept without the key.[/yellow]"
                                )

                        if store_obbject:
                            # store the obbject in the registry
                            register_result = session.obbject_registry.register(
                                obbject
                            )

                            # we need to force to re-link so that the new obbject
                            # is immediately available for data processing commands
                            self._link_obbject_to_data_processing_commands()
                            # also update the completer
                            self.update_completer(self.choices_default)

                            if (
                                session.settings.SHOW_MSG_OBBJECT_REGISTRY
                                and register_result
                            ):
                                session.console.print(
                                    "Added `OBBject` to cached results."
                                )

                        # making the dataframe available either for printing or exporting
                        df = obbject.to_dataframe()

                        if hasattr(ns_parser, "chart") and ns_parser.chart:
                            fig = obbject.chart.fig if obbject.chart else None
                            if not export:
                                obbject.show()
                        elif session.settings.USE_INTERACTIVE_DF and not export:
                            obbject.charting.table()
                        else:
                            # Integer column labels are turned into strings
                            # before printing.
                            if isinstance(df.columns, pd.RangeIndex):
                                df.columns = [str(i) for i in df.columns]

                            print_rich_table(
                                df=df, show_index=True, title=title, export=export
                            )

                    elif isinstance(obbject, dict):
                        df = pd.DataFrame.from_dict(obbject, orient="columns")
                        print_rich_table(
                            df=df, show_index=True, title=title, export=export
                        )

                    else:
                        # Any other result type is printed as-is.
                        session.console.print(obbject)

                if export and not df.empty:
                    sheet_name = getattr(ns_parser, "sheet_name", None)
                    # sheet_name may arrive as a list; only its first item is used.
                    if sheet_name and isinstance(sheet_name, list):
                        sheet_name = sheet_name[0]

                    export_data(
                        export_type=",".join(ns_parser.export),
                        dir_path=os.path.dirname(os.path.abspath(__file__)),
                        func_name=translator.func.__name__,
                        df=df,
                        sheet_name=sheet_name,
                        figure=fig,
                    )
                elif export and df.empty:
                    session.console.print("[yellow]No data to export.[/yellow]")

            except Exception as e:  # pylint: disable=broad-exception-caught
                # CLI boundary: report the error and keep the session running.
                session.console.print(f"[red]{e}[/]\n")

        # Bind the method to the instance
        bound_method = MethodType(method, self)

        # Freeze the translator argument and copy the metadata of ``method``
        bound_method = update_wrapper(  # type: ignore
            partial(bound_method, translator=translator), method
        )
        setattr(self, f"call_{name}", bound_method)

    def _generate_controller_call(self, controller, name, parent_path, translators):
        """Attach a ``call_<name>`` method that opens ``controller`` as a sub-menu.

        The generated method passes the controller class, its name, parent
        path, translators and the current queue to ``self.load_class``. It
        then stores the queue that ``load_class`` returns.
        """

        def method(self, _, controller, name, parent_path, translators):
            """Call the controller."""
            self.queue = self.load_class(
                class_ins=controller,
                name=name,
                parent_path=parent_path,
                translators=translators,
                queue=self.queue,
            )

        # Bind the method to the instance
        bound_method = MethodType(method, self)

        # Freeze the sub-menu arguments and copy the metadata of ``method``
        bound_method = update_wrapper(  # type: ignore
            partial(
                bound_method,
                name=name,
                parent_path=parent_path,
                translators=translators,
                controller=controller,
            ),
            method,
        )
        setattr(self, f"call_{name}", bound_method)

    def _get_command_description(self, command: str) -> str:
        """Return a short, lowercase description of ``command`` for the help menu.

        The description is looked up in ``obb.reference["paths"]`` under
        ``PATH + command``. If that entry has no description, the translator's
        parser description is used instead. Only the text before the first
        "." is kept. Returns ``""`` when no description is available,
        including when the parser's description is ``None``.
        """
        command_description = (
            obb.reference["paths"]  # type: ignore
            .get(f"{self.PATH}{command}", {})
            .get("description", "")
        )
        if not command_description:
            trl = self.translators.get(
                f"{self._name}_{command}"
            ) or self.translators.get(command)
            if trl and hasattr(trl, "parser"):
                command_description = trl.parser.description

        # argparse parsers default to description=None; guard before splitting.
        return (command_description or "").split(".")[0].lower()

    def _get_menu_description(self, menu: str) -> str:
        """Return a short, lowercase description of ``menu`` for the help menu.

        The description is looked up in ``obb.reference["routers"]`` under
        ``PATH + menu``, keeping only the text before the first ".". If none
        is found, the sub-menu's command names are returned, comma-separated.
        """

        def _get_sub_menu_commands():
            """Return the sub-menu's commands (translator names minus the menu prefix)."""
            sub_path = f"{self.PATH[1:].replace('/', '_')}{menu}"
            return [
                trl.replace(f"{sub_path}_", "")
                for trl in self.translators
                if sub_path in trl
            ]

        menu_description = (
            obb.reference["routers"]  # type: ignore
            .get(f"{self.PATH}{menu}", {})
            .get("description", "")
        ) or ""
        if menu_description:
            return menu_description.split(".")[0].lower()

        # If no description is found, return the sub menu commands
        return ", ".join(_get_sub_menu_commands())

    def print_help(self):
        """Print the menu: sub-menus, commands, cached results and warnings."""
        mt = MenuText(self.PATH)

        if self.CHOICES_MENUS:
            for menu in self.CHOICES_MENUS:
                description = self._get_menu_description(menu)
                mt.add_menu(name=menu, description=description)

            # Separate the menus from the commands that follow them.
            if self.CHOICES_COMMANDS:
                mt.add_raw("\n")

        if self.CHOICES_COMMANDS:
            for command in self.CHOICES_COMMANDS:
                command_description = self._get_command_description(command)
                mt.add_cmd(
                    name=command.replace(f"{self._name}_", ""),
                    description=command_description,
                )

        if session.obbject_registry.obbjects:
            mt.add_info("\nCached Results")
            # Show at most N_TO_DISPLAY_OBBJECT_REGISTRY registry entries.
            for key, value in list(session.obbject_registry.all.items())[
                : session.settings.N_TO_DISPLAY_OBBJECT_REGISTRY
            ]:
                mt.add_raw(
                    f"[yellow]OBB{key}[/yellow]: {value['command']}",
                    left_spacing=True,
                )

        session.console.print(text=mt.menu_text, menu=self.PATH)

        if mt.warnings:
            session.console.print("")
            for w in mt.warnings:
                # Strip dict/quote punctuation so the warning reads as plain text.
                w_str = str(w).replace("{", "").replace("}", "").replace("'", "")
                session.console.print(f"[yellow]{w_str}[/yellow]")
            session.console.print("")