import gradio as gr
from util import css, js, Color
import pandas as pd
from trading_floor import names, lastnames, short_model_names
import plotly.express as px
from accounts import Account
from database import read_log
mapper = {
"trace": Color.WHITE,
"agent": Color.CYAN,
"function": Color.GREEN,
"generation": Color.YELLOW,
"response": Color.MAGENTA,
"account": Color.RED,
}
class Trader:
def __init__(self, name: str, lastname: str, model_name: str):
self.name = name
self.lastname = lastname
self.model_name = model_name
self.account = Account.get(name)
def reload(self):
self.account = Account.get(self.name)
def get_title(self) -> str:
return f"
{self.name} ({self.model_name}) - {self.lastname}
"
def get_strategy(self) -> str:
return self.account.get_strategy()
def get_portfolio_value_df(self) -> pd.DataFrame:
df = pd.DataFrame(self.account.portfolio_value_time_series, columns=["datetime", "value"])
df["datetime"] = pd.to_datetime(df["datetime"])
return df
def get_portfolio_value_chart(self):
df = self.get_portfolio_value_df()
fig = px.line(df, x="datetime", y="value")
margin = dict(l=40, r=20, t=20, b=40)
fig.update_layout(
height=300,
margin=margin,
xaxis_title=None,
yaxis_title=None,
paper_bgcolor="#bbb",
plot_bgcolor="#dde",
)
fig.update_xaxes(tickformat="%m/%d", tickangle=45, tickfont=dict(size=8))
fig.update_yaxes(tickfont=dict(size=8), tickformat=",.0f")
return fig
def get_holdings_df(self) -> pd.DataFrame:
"""Convert holdings to DataFrame for display"""
holdings = self.account.get_holdings()
if not holdings:
return pd.DataFrame(columns=["Symbol", "Quantity"])
df = pd.DataFrame(
[{"Symbol": symbol, "Quantity": quantity} for symbol, quantity in holdings.items()]
)
return df
def get_transactions_df(self) -> pd.DataFrame:
"""Convert transactions to DataFrame for display"""
transactions = self.account.list_transactions()
if not transactions:
return pd.DataFrame(columns=["Timestamp", "Symbol", "Quantity", "Price", "Rationale"])
return pd.DataFrame(transactions)
def get_portfolio_value(self) -> str:
"""Calculate total portfolio value based on current prices"""
portfolio_value = self.account.calculate_portfolio_value() or 0.0
pnl = self.account.calculate_profit_loss(portfolio_value) or 0.0
color = "green" if pnl >= 0 else "red"
emoji = "⬆" if pnl >= 0 else "⬇"
return f"${portfolio_value:,.0f} {emoji} ${pnl:,.0f}
"
def get_logs(self, previous=None) -> str:
logs = read_log(self.name, last_n=13)
response = ""
for log in logs:
timestamp, type, message = log
color = mapper.get(type, Color.WHITE).value
response += f"{timestamp} : [{type}] {message}
"
response = f"{response}
"
if response != previous:
return response
return gr.update()
class TraderView:
def __init__(self, trader: Trader):
self.trader = trader
self.portfolio_value = None
self.chart = None
self.holdings_table = None
self.transactions_table = None
def make_ui(self):
with gr.Column():
gr.HTML(self.trader.get_title())
with gr.Row():
self.portfolio_value = gr.HTML(self.trader.get_portfolio_value)
with gr.Row():
self.chart = gr.Plot(
self.trader.get_portfolio_value_chart, container=True, show_label=False
)
with gr.Row(variant="panel"):
self.log = gr.HTML(self.trader.get_logs)
with gr.Row():
self.holdings_table = gr.Dataframe(
value=self.trader.get_holdings_df,
label="Holdings",
headers=["Symbol", "Quantity"],
row_count=(5, "dynamic"),
col_count=2,
max_height=300,
elem_classes=["dataframe-fix-small"],
)
with gr.Row():
self.transactions_table = gr.Dataframe(
value=self.trader.get_transactions_df,
label="Recent Transactions",
headers=["Timestamp", "Symbol", "Quantity", "Price", "Rationale"],
row_count=(5, "dynamic"),
col_count=5,
max_height=300,
elem_classes=["dataframe-fix"],
)
timer = gr.Timer(value=120)
timer.tick(
fn=self.refresh,
inputs=[],
outputs=[
self.portfolio_value,
self.chart,
self.holdings_table,
self.transactions_table,
],
show_progress="hidden",
queue=False,
)
log_timer = gr.Timer(value=0.5)
log_timer.tick(
fn=self.trader.get_logs,
inputs=[self.log],
outputs=[self.log],
show_progress="hidden",
queue=False,
)
def refresh(self):
self.trader.reload()
return (
self.trader.get_portfolio_value(),
self.trader.get_portfolio_value_chart(),
self.trader.get_holdings_df(),
self.trader.get_transactions_df(),
)
# Main UI construction
def create_ui():
"""Create the main Gradio UI for the trading simulation"""
traders = [
Trader(trader_name, lastname, model_name)
for trader_name, lastname, model_name in zip(names, lastnames, short_model_names)
]
trader_views = [TraderView(trader) for trader in traders]
with gr.Blocks(
title="Traders", css=css, js=js, theme=gr.themes.Default(primary_hue="sky"), fill_width=True
) as ui:
with gr.Row():
for trader_view in trader_views:
trader_view.make_ui()
return ui
if __name__ == "__main__":
ui = create_ui()
ui.launch(inbrowser=True)