|
|
"""Platform Equity Controller.""" |
|
|
|
|
|
import os |
|
|
from functools import partial, update_wrapper |
|
|
from types import MethodType |
|
|
from typing import Dict, List, Optional |
|
|
|
|
|
import pandas as pd |
|
|
from openbb import obb |
|
|
from openbb_charting.core.openbb_figure import OpenBBFigure |
|
|
from openbb_cli.argparse_translator.argparse_class_processor import ( |
|
|
ArgparseClassProcessor, |
|
|
) |
|
|
from openbb_cli.config.menu_text import MenuText |
|
|
from openbb_cli.controllers.base_controller import BaseController |
|
|
from openbb_cli.controllers.utils import export_data, print_rich_table |
|
|
from openbb_cli.session import Session |
|
|
from openbb_core.app.model.obbject import OBBject |
|
|
|
|
|
# Module-wide CLI session; the controller below reads its console, settings
# and OBBject registry from this object.
session = Session()
|
|
|
|
|
|
|
|
class DummyTranslation:
    """Dummy Translation for testing.

    Exposes the two attributes read from a translated target in this file,
    ``paths`` and ``translators``, both starting out as empty dictionaries.
    """

    def __init__(self):
        """Construct a Dummy Translation Class with empty lookup tables."""
        self.translators: dict = {}
        self.paths: dict = {}
|
|
|
|
|
|
|
|
class PlatformController(BaseController):
    """Platform Controller Base class.

    Turns a set of argparse translators into CLI commands: each translator is
    exposed on the instance as a ``call_<command>`` attribute, and each entry
    of ``paths`` whose value is not ``"path"`` is exposed as a dynamically
    created sub-controller reachable through ``call_<path>``.
    """

    CHOICES_GENERATION = True

    def __init__(
        self,
        name: str,
        parent_path: List[str],
        platform_target: Optional[type] = None,
        queue: Optional[List[str]] = None,
        translators: Optional[Dict] = None,
    ):
        """Construct a Platform based Controller.

        Parameters
        ----------
        name : str
            Menu name; also the prefix stripped from translator names.
        parent_path : List[str]
            Path segments of the parent menus. Empty for a top-level menu.
        platform_target : Optional[type]
            Class handed to ``ArgparseClassProcessor`` to build the
            translators and paths. When falsy, a ``DummyTranslation`` is used.
        queue : Optional[List[str]]
            Command queue forwarded to ``BaseController.__init__``.
        translators : Optional[Dict]
            Pre-built translators. When not ``None`` they take precedence over
            the ones produced from ``platform_target``.

        Raises
        ------
        ValueError
            If both ``platform_target`` and ``translators`` are falsy.
        """
        # PATH must exist before the base initialiser runs.
        self.PATH = f"/{'/'.join(parent_path)}/{name}/" if parent_path else f"/{name}/"
        super().__init__(queue)
        self._name = name

        if not (platform_target or translators):
            raise ValueError("Either platform_target or translators must be provided.")

        self._translated_target = (
            ArgparseClassProcessor(
                target_class=platform_target, reference=obb.reference["paths"]
            )
            if platform_target
            else DummyTranslation()
        )
        self.translators = (
            translators
            if translators is not None
            else getattr(self._translated_target, "translators", {})
        )
        self.paths = getattr(self._translated_target, "paths", {})

        if self.translators:
            self._link_obbject_to_data_processing_commands()
            self._generate_commands()
            self._generate_sub_controllers()
            self.update_completer(self.choices_default)

    def _link_obbject_to_data_processing_commands(self):
        """Link data processing commands to OBBject registry.

        For every parser action whose destination is ``data``, the allowed
        choices become ``OBB<index>`` for each registered OBBject plus every
        ``register_key`` stored in an OBBject's ``extra``. The action is also
        switched to a single string argument.
        """
        for trl in self.translators.values():
            for action in trl._parser._actions:
                if action.dest != "data":
                    continue

                registered = session.obbject_registry.obbjects
                indexed_choices = [f"OBB{i}" for i in range(len(registered))]
                keyed_choices = [
                    obbject.extra["register_key"]
                    for obbject in registered
                    if "register_key" in obbject.extra
                ]
                action.choices = indexed_choices + keyed_choices
                action.type = str
                action.nargs = None

    def _intersect_data_processing_commands(self, ns_parser):
        """Intersect data processing commands and change the obbject id into an actual obbject.

        When ``ns_parser.data`` has the exact form ``OBB<digits>`` it is
        replaced by the integer index. If the resulting value is a valid
        registry index or a known register key, and the registry returns an
        ``OBBject`` for it, ``ns_parser.data`` is replaced by that object's
        ``results``.

        Returns the (possibly modified) ``ns_parser``.
        """
        if not hasattr(ns_parser, "data"):
            return ns_parser

        data = ns_parser.data
        # Only an exact ``OBB<digits>`` token is an index. A register key that
        # merely contains "OBB" must stay a string (a plain substring test
        # followed by ``int()`` would raise ValueError on it), and non-string
        # values must not be probed with ``in`` at all.
        if isinstance(data, str) and data.startswith("OBB") and data[3:].isdecimal():
            data = int(data[3:])
            ns_parser.data = data

        registry = session.obbject_registry
        if (data in range(len(registry.obbjects))) or (data in registry.obbject_keys):
            obbject = registry.get(data)
            if obbject and isinstance(obbject, OBBject):
                setattr(ns_parser, "data", obbject.results)

        return ns_parser

    def _generate_sub_controllers(self):
        """Handle paths.

        For each entry of ``self.paths`` whose value is not ``"path"``, collect
        the translators whose name contains ``<name>_<path>``, create a
        ``PlatformController`` subclass for them and register a
        ``call_<path>`` attribute that opens it.
        """
        for path, value in self.paths.items():
            if value == "path":
                continue

            sub_menu_translators = {}
            choices_commands = []
            marker = f"{self._name}_{path}"

            for translator_name, translator in self.translators.items():
                if marker not in translator_name:
                    continue

                new_name = translator_name.replace(f"{marker}_", "")
                sub_menu_translators[new_name] = translator
                choices_commands.append(new_name)

                # The command now lives in the sub menu, so drop it from this
                # menu's command list.
                if translator_name in self.CHOICES_COMMANDS:
                    self.CHOICES_COMMANDS.remove(translator_name)

            # Create the sub controller as a new class.
            class_name = f"{self._name.capitalize()}{path.capitalize()}Controller"
            SubController = type(
                class_name,
                (PlatformController,),
                {
                    "CHOICES_GENERATION": True,
                    "CHOICES_COMMANDS": choices_commands,
                },
            )

            self._generate_controller_call(
                controller=SubController,
                name=path,
                parent_path=self.path,
                translators=sub_menu_translators,
            )

    def _generate_commands(self):
        """Generate commands.

        Registers one ``call_<command>`` attribute per translator, where the
        command name is the translator name with ``<name>_`` removed.
        """
        prefix = f"{self._name}_"
        for name, translator in self.translators.items():
            self._generate_command_call(
                name=name.replace(prefix, ""), translator=translator
            )

    def _generate_command_call(self, name, translator):
        """Generate command call.

        Builds a function that parses the CLI arguments with the translator's
        parser, executes the translator, optionally stores the resulting
        OBBject in the session registry, displays the result and exports it
        when requested. The function is bound to this instance and stored as
        ``call_<name>``.

        Any exception raised while running the command is caught and printed
        to the session console instead of propagating.
        """

        def method(self, other_args: List[str], translator=translator):
            """Call the translator."""
            parser = translator.parser

            if ns_parser := self.parse_known_args_and_warn(
                parser=parser,
                other_args=other_args,
                export_allowed="raw_data_and_figures",
            ):
                try:
                    ns_parser = self._intersect_data_processing_commands(ns_parser)
                    export = hasattr(ns_parser, "export") and ns_parser.export
                    store_obbject = (
                        hasattr(ns_parser, "register_obbject")
                        and ns_parser.register_obbject
                    )

                    obbject = translator.execute_func(parsed_args=ns_parser)
                    df: pd.DataFrame = pd.DataFrame()
                    fig: Optional[OpenBBFigure] = None
                    title = f"{self.PATH}{translator.func.__name__}"

                    if obbject:
                        if isinstance(obbject, list):
                            obbject = OBBject(results=obbject)

                        if isinstance(obbject, OBBject):
                            if (
                                session.max_obbjects_exceeded()
                                and obbject.results
                                and store_obbject
                            ):
                                session.obbject_registry.remove()
                                session.console.print(
                                    "[yellow]Maximum number of OBBjects reached. "
                                    "The oldest entry was removed.[/yellow]"
                                )

                            # Keep the command on the obbject so it can be
                            # displayed later alongside the cached results.
                            obbject.extra["command"] = f"{title} {' '.join(other_args)}"

                            # Attach the requested registry key unless it is
                            # already taken.
                            if (
                                hasattr(ns_parser, "register_key")
                                and ns_parser.register_key
                            ):
                                if (
                                    ns_parser.register_key
                                    not in session.obbject_registry.obbject_keys
                                ):
                                    obbject.extra["register_key"] = str(
                                        ns_parser.register_key
                                    )
                                else:
                                    session.console.print(
                                        f"[yellow]Key `{ns_parser.register_key}` already exists in the registry. "
                                        "The `OBBject` was kept without the key.[/yellow]"
                                    )

                            if store_obbject:
                                register_result = session.obbject_registry.register(
                                    obbject
                                )

                                # Re-link so the new obbject is immediately
                                # selectable by data processing commands, then
                                # refresh the completer.
                                self._link_obbject_to_data_processing_commands()
                                self.update_completer(self.choices_default)

                                if (
                                    session.settings.SHOW_MSG_OBBJECT_REGISTRY
                                    and register_result
                                ):
                                    session.console.print(
                                        "Added `OBBject` to cached results."
                                    )

                            # The dataframe is needed for printing or exporting.
                            df = obbject.to_dataframe()

                            if hasattr(ns_parser, "chart") and ns_parser.chart:
                                fig = obbject.chart.fig if obbject.chart else None
                                if not export:
                                    obbject.show()
                            elif session.settings.USE_INTERACTIVE_DF and not export:
                                obbject.charting.table()
                            else:
                                if isinstance(df.columns, pd.RangeIndex):
                                    df.columns = [str(i) for i in df.columns]

                                print_rich_table(
                                    df=df, show_index=True, title=title, export=export
                                )

                        elif isinstance(obbject, dict):
                            df = pd.DataFrame.from_dict(obbject, orient="columns")
                            print_rich_table(
                                df=df, show_index=True, title=title, export=export
                            )

                        elif not isinstance(obbject, OBBject):
                            session.console.print(obbject)

                    if export and not df.empty:
                        sheet_name = getattr(ns_parser, "sheet_name", None)
                        if sheet_name and isinstance(sheet_name, list):
                            sheet_name = sheet_name[0]

                        export_data(
                            export_type=",".join(ns_parser.export),
                            dir_path=os.path.dirname(os.path.abspath(__file__)),
                            func_name=translator.func.__name__,
                            df=df,
                            sheet_name=sheet_name,
                            figure=fig,
                        )
                    elif export and df.empty:
                        session.console.print("[yellow]No data to export.[/yellow]")

                except Exception as e:
                    # Report the failure on the console instead of raising.
                    session.console.print(f"[red]{e}[/]\n")
                    return

        # Bind the method to this instance.
        bound_method = MethodType(method, self)

        # Fix the translator argument and carry over the function metadata.
        bound_method = update_wrapper(
            partial(bound_method, translator=translator), method
        )
        setattr(self, f"call_{name}", bound_method)

    def _generate_controller_call(self, controller, name, parent_path, translators):
        """Generate controller call.

        Stores a ``call_<name>`` attribute on this instance which, when
        invoked, loads ``controller`` through ``self.load_class`` with the
        given ``name``, ``parent_path`` and ``translators`` and assigns the
        result to ``self.queue``.
        """

        def method(self, _, controller, name, parent_path, translators):
            """Call the controller."""
            self.queue = self.load_class(
                class_ins=controller,
                name=name,
                parent_path=parent_path,
                translators=translators,
                queue=self.queue,
            )

        # Bind the method to this instance.
        bound_method = MethodType(method, self)

        # Fix the controller arguments and carry over the function metadata.
        bound_method = update_wrapper(
            partial(
                bound_method,
                name=name,
                parent_path=parent_path,
                translators=translators,
                controller=controller,
            ),
            method,
        )
        setattr(self, f"call_{name}", bound_method)

    def _get_command_description(self, command: str) -> str:
        """Get command description.

        Looks the command up in ``obb.reference["paths"]`` first and falls back
        to the description of the matching translator's parser. Returns the
        text before the first ``.``, lower-cased, or an empty string when no
        description is available.
        """
        command_description = (
            obb.reference["paths"]
            .get(f"{self.PATH}{command}", {})
            .get("description", "")
        )

        if not command_description:
            trl = self.translators.get(
                f"{self._name}_{command}"
            ) or self.translators.get(command)
            if trl and hasattr(trl, "parser"):
                command_description = trl.parser.description

        # Either source may yield None (e.g. a parser built without a
        # description); treat that as "no description" rather than failing.
        return (command_description or "").split(".")[0].lower()

    def _get_menu_description(self, menu: str) -> str:
        """Get menu description.

        Returns the first sentence (lower-cased) of the description found in
        ``obb.reference["routers"]`` for this menu. When there is none, returns
        a comma separated list of the sub menu's command names derived from the
        translator names.
        """

        def _get_sub_menu_commands():
            """Get sub menu commands."""
            sub_path = f"{self.PATH[1:].replace('/','_')}{menu}"
            return [
                trl.replace(f"{sub_path}_", "")
                for trl in self.translators
                if sub_path in trl
            ]

        menu_description = (
            obb.reference["routers"]
            .get(f"{self.PATH}{menu}", {})
            .get("description", "")
        ) or ""
        if menu_description:
            return menu_description.split(".")[0].lower()

        # No router description available: list the sub menu commands instead.
        return ", ".join(_get_sub_menu_commands())

    def print_help(self):
        """Print help.

        Prints the menus, the commands and the cached OBBjects (limited to
        ``N_TO_DISPLAY_OBBJECT_REGISTRY`` entries) for this controller,
        followed by any warnings collected while building the menu text.
        """
        mt = MenuText(self.PATH)

        if self.CHOICES_MENUS:
            for menu in self.CHOICES_MENUS:
                description = self._get_menu_description(menu)
                mt.add_menu(name=menu, description=description)

            # Separate the menus from the commands that follow.
            if self.CHOICES_COMMANDS:
                mt.add_raw("\n")

        if self.CHOICES_COMMANDS:
            for command in self.CHOICES_COMMANDS:
                command_description = self._get_command_description(command)
                mt.add_cmd(
                    name=command.replace(f"{self._name}_", ""),
                    description=command_description,
                )

        if session.obbject_registry.obbjects:
            mt.add_info("\nCached Results")
            for key, value in list(session.obbject_registry.all.items())[
                : session.settings.N_TO_DISPLAY_OBBJECT_REGISTRY
            ]:
                mt.add_raw(
                    f"[yellow]OBB{key}[/yellow]: {value['command']}",
                    left_spacing=True,
                )

        session.console.print(text=mt.menu_text, menu=self.PATH)

        if mt.warnings:
            session.console.print("")
            for w in mt.warnings:
                w_str = str(w).replace("{", "").replace("}", "").replace("'", "")
                session.console.print(f"[yellow]{w_str}[/yellow]")
            session.console.print("")
|
|
|