# Trade / app.py
# Fangzhi Xu
# Config
# 006d68a
import gradio as gr
import numpy as np
import json
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib
import os
import plotly.graph_objects as go
import pandas as pd
matplotlib.use('Agg')
class TradeArenaEnv_Deterministic:
    """
    Odyssey Arena - AI Trading Environment (Deterministic version)

    Day-by-day trading simulator. Each ``step`` does the following:
    - executes the agent's sell and buy orders at the current prices;
    - advances the day counter;
    - applies the news scheduled for the new day, which shifts the latent
      variables.

    Prices then move by ``dependency_matrix @ variable_changes``, plus
    Gaussian noise only when ``price_noise_scale`` is non-zero, and are
    floored at 0.1.

    Config keys read from ``cfg``: ``num_days``, ``stocks`` (tickers),
    ``variables`` (names), ``dependency_matrix``, ``initial_prices``,
    ``initial_variables`` and ``timeline``. Optional keys are
    ``price_noise_scale`` (default 0.0) and ``initial_cash`` (default
    10000.0). ``timeline`` maps ``"day_<n>"`` to
    ``{"variable_changes": [...], "news_text": "..."}``.
    """

    def __init__(self, cfg):
        self.num_days = cfg["num_days"]
        self.stocks = cfg["stocks"]
        self.variables = cfg["variables"]
        # Multiplied with a variable-delta vector and added to prices, so it is
        # effectively (len(stocks), len(variables)).
        self.dependency_matrix = np.array(cfg["dependency_matrix"])
        self.initial_prices = np.array(cfg["initial_prices"])
        self.initial_variables = np.array(cfg["initial_variables"])
        self.timeline = cfg["timeline"]
        self.price_noise_scale = cfg.get("price_noise_scale", 0.0)
        self.initial_cash = cfg.get("initial_cash", 10000.0)
        self.reset()

    def reset(self):
        """Restore the day-0 state (cash, empty positions, initial prices) and return the observation."""
        self.t = 0
        self.cash = self.initial_cash
        self.positions = np.zeros(len(self.stocks), dtype=np.float64)
        self.prices = self.initial_prices.copy().astype(np.float64)
        self.variables_state = self.initial_variables.copy().astype(np.float64)
        # The agent always sees the *upcoming* day's news in advance.
        self.next_day_news = self.timeline.get("day_1", None)
        return self._get_observation()

    def _get_observation(self):
        """Snapshot the current day, prices, cash, integer positions, total value and next-day news."""
        obs = {
            "day": self.t,
            "prices": {s: float(p) for s, p in zip(self.stocks, self.prices)},
            "cash": float(self.cash),
            "positions": {s: int(pos) for s, pos in zip(self.stocks, self.positions)},
            "total_value": float(self.cash + np.sum(self.positions * self.prices)),
            "news_next_day": self.next_day_news["variable_changes"] if self.next_day_news else None,
            "news_next_day_text": self.next_day_news["news_text"] if self.next_day_news else None
        }
        return obs

    def step(self, action):
        """Execute one day of trading and advance the clock.

        Args:
            action: ``{"sell": {ticker: qty}, "buy": {ticker: qty}}``; either
                key may be omitted. Quantities are truncated with ``int`` and
                non-positive ones are ignored. Sells are capped at the current
                holding; a buy costing more than the available cash is skipped.
                An unknown ticker raises ``ValueError`` (from ``list.index``).

        Returns:
            ``(obs, reward, done, info)`` where ``reward`` is the total
            portfolio value rounded to two decimals and ``info`` is ``{}``.
        """
        assert isinstance(action, dict)
        # Sells execute first so their proceeds can fund the same day's buys.
        for stock, qty in action.get("sell", {}).items():
            idx = self.stocks.index(stock)
            qty = int(qty)
            if qty <= 0:
                # A negative sell would credit negative revenue and grow the
                # position, i.e. an unchecked buy; zero is a no-op anyway.
                continue
            qty = min(qty, self.positions[idx])
            revenue = self.prices[idx] * qty
            self.positions[idx] -= qty
            self.cash += revenue
        # Then buys
        for stock, qty in action.get("buy", {}).items():
            idx = self.stocks.index(stock)
            qty = int(qty)
            if qty <= 0:
                # A negative buy would pass the cash check, open a short
                # position and add cash; zero is a no-op anyway.
                continue
            cost = self.prices[idx] * qty
            if cost <= self.cash:
                self.positions[idx] += qty
                self.cash -= cost
        # Advance one day
        self.t += 1
        done = self.t >= self.num_days
        # Update variable states & prices from the news scheduled for the new day.
        # NOTE(review): at t == num_days - 1 the observation announces the news
        # for day `num_days`, but the final step sets `done` and skips this
        # update, so that news never moves prices. Confirm this is intended.
        if not done:
            news_today = self.timeline.get(f"day_{self.t}", None)
            if news_today:
                deltas = np.array(news_today["variable_changes"])
                self.variables_state += deltas
                self._update_prices_from_variables(deltas)
        # Prepare next day's news
        self.next_day_news = self.timeline.get(f"day_{self.t + 1}", None) if not done else None
        reward = self._compute_reward()
        obs = self._get_observation()
        return obs, reward, done, {}

    def _update_prices_from_variables(self, delta_vars):
        """Shift prices by ``dependency_matrix @ delta_vars`` (+ optional noise), floored at 0.1."""
        delta_price = self.dependency_matrix @ delta_vars
        # Noise is drawn from NumPy's global RNG only when a non-zero scale is configured.
        noise = np.zeros_like(delta_price) if self.price_noise_scale == 0 else np.random.normal(
            0, self.price_noise_scale, len(self.stocks)
        )
        self.prices += delta_price + noise
        self.prices = np.clip(self.prices, 0.1, None)

    def _compute_reward(self):
        """Return cash plus the market value of all positions, rounded to cents."""
        total_value = self.cash + np.sum(self.positions * self.prices)
        return round(float(total_value), 2)
# Built-in scenario used when no custom config is supplied: three stocks, each
# driven by three macro variables through `dependency_matrix`. Scheduled news
# only covers days 1-5 of the 30-day run.
DEFAULT_CONFIG = dict(
    num_days=30,
    stocks=["TECH", "ENERGY", "FINANCE"],
    variables=["interest_rate", "oil_price", "market_sentiment"],
    # One row per stock, one column per variable.
    dependency_matrix=[
        [-5, 2, 3],
        [1, 8, 2],
        [-3, 1, 4],
    ],
    initial_prices=[100, 80, 120],
    initial_variables=[0, 0, 0],
    initial_cash=10000,
    price_noise_scale=0,
    timeline=dict(
        day_1=dict(
            variable_changes=[0.1, -0.2, 0.3],
            news_text="Federal Reserve hints at rate increase; Oil prices drop on oversupply concerns",
        ),
        day_2=dict(
            variable_changes=[-0.1, 0.3, 0.2],
            news_text="Tech sector shows strong earnings; Energy stocks rally on production cuts",
        ),
        day_3=dict(
            variable_changes=[0.2, 0.1, -0.1],
            news_text="Market sentiment cautious amid geopolitical tensions",
        ),
        day_4=dict(
            variable_changes=[0.0, 0.2, 0.1],
            news_text="Stable interest rates; Energy sector momentum continues",
        ),
        day_5=dict(
            variable_changes=[-0.2, -0.1, 0.0],
            news_text="Rate cut speculation; Market consolidation",
        ),
    ),
)
# ===== New: config directory support =====
def list_config_files():
    """Return the sorted names of the ``*.json`` files in the ``config`` directory.

    ``config`` is resolved relative to the current working directory. The
    names are sorted because ``os.listdir`` returns entries in arbitrary
    order, and the UI uses the first entry as the dropdown default.

    Returns:
        A list of file names. It is empty when ``config`` does not exist or
        is not a directory.
    """
    config_dir = "config"
    # isdir (not exists): a plain file named "config" would make listdir raise.
    if not os.path.isdir(config_dir):
        return []
    return sorted(f for f in os.listdir(config_dir) if f.endswith(".json"))
def load_config_from_file(filename):
    """Load ``config/<filename>`` and return it pretty-printed for the config textbox.

    Args:
        filename: File name selected in the dropdown. It may be None when the
            ``config`` folder has no JSON files.

    Returns:
        The parsed JSON re-serialised with two-space indentation. If nothing is
        selected, or the file cannot be read or parsed, a human-readable error
        string is returned instead; it lands in the same textbox.
    """
    if not filename:
        return "โŒ No config file selected"
    path = os.path.join("config", filename)
    try:
        # JSON is UTF-8 by spec; don't depend on the platform's default encoding.
        with open(path, "r", encoding="utf-8") as f:
            cfg = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        return f"โŒ Failed to load {filename}: {str(e)}"
    return json.dumps(cfg, indent=2)
# ---- Module-level session state, read and rebound by the UI callbacks ----
# env:     the active TradeArenaEnv_Deterministic, or None before initialisation
# history: per-day snapshots {'day', 'total_value', <stock>: price, ...} used by the charts
env, history = None, []
def initialize_env(config_file=None):
    """Start a fresh simulation from pasted JSON text or from ``DEFAULT_CONFIG``.

    Args:
        config_file: JSON text from the config textbox. ``None`` or blank text
            selects ``DEFAULT_CONFIG``.

    Returns:
        A 5-tuple ``(status, portfolio_df, news_html, price_chart, value_chart)``
        matching the Gradio outputs. On failure only ``status`` is set, and
        the previously loaded environment and history are left untouched.
    """
    global env, history
    if config_file is not None and config_file.strip():
        try:
            config = json.loads(config_file)
        except json.JSONDecodeError:
            return "โŒ Invalid JSON file", None, None, None, None
        if not isinstance(config, dict):
            return "โŒ Invalid config: expected a JSON object", None, None, None, None
    else:
        config = DEFAULT_CONFIG
    try:
        # Build into a local first so a malformed config (missing keys,
        # mismatched shapes) does not replace the running session.
        new_env = TradeArenaEnv_Deterministic(config)
    except (KeyError, TypeError, ValueError, AttributeError, IndexError) as e:
        return f"โŒ Invalid config: {type(e).__name__}: {e}", None, None, None, None
    env = new_env
    obs = env.reset()
    # History starts with the day-0 snapshot; the charts need more than one point.
    history = [{
        'day': obs['day'],
        'total_value': obs['total_value'],
        **obs['prices']
    }]
    status = f"โœ… Session initialized!\n๐Ÿ“… Day: {obs['day']}\n๐Ÿ’ฐ Cash: ${obs['cash']:.2f}\n๐Ÿ“Š Total Value: ${obs['total_value']:.2f}"
    return (
        status,
        create_portfolio_display(obs),
        create_news_display(obs),
        create_price_chart(),
        create_value_chart()
    )
def create_portfolio_display(obs):
    """Tabulate each stock's price, share count and market value for the portfolio grid.

    Rows follow ``env.stocks`` order. Price and value are pre-formatted dollar
    strings; holdings are left as the observation's integer share counts.
    """
    rows = [
        {
            'Stock': name,
            'Price': f"${obs['prices'][name]:.2f}",
            'Holdings': obs['positions'][name],
            'Value': f"${obs['prices'][name] * obs['positions'][name]:.2f}",
        }
        for name in env.stocks
    ]
    return pd.DataFrame(rows)
def create_news_display(obs):
    """Render the upcoming day's headline and variable changes as an HTML card.

    Args:
        obs: An observation dict. It reads ``news_next_day_text`` (headline
            or None) and ``news_next_day`` (per-variable deltas or None).
            Change labels come from ``env.variables``.

    Returns:
        The HTML of the news card, or a grey placeholder when there is no
        headline for the next day.
    """
    import html

    headline = obs['news_next_day_text']
    if not headline:
        return "<div style='padding: 20px; background: #f0f0f0; border-radius: 10px; text-align: center;'>๐Ÿ“ญ No more news available</div>"
    # Headline and variable names come from the loaded config and are placed
    # into raw HTML, so escape them so '<' or '&' cannot break or inject markup.
    news_html = f"""
<div style='background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
padding: 20px; border-radius: 10px; color: white; margin: 10px 0;'>
<h3 style='margin-top: 0;'>๐Ÿ“ฐ Next Day News</h3>
<p style='font-size: 16px; line-height: 1.6;'>{html.escape(str(headline))}</p>
"""
    changes = obs['news_next_day']
    if changes:
        news_html += "<p style='font-size: 14px; margin-top: 10px;'><b>Variable Changes:</b><br/>"
        # zip stops at the shorter list, so a timeline entry whose change list
        # is shorter than env.variables no longer raises IndexError.
        for var, change in zip(env.variables, changes):
            news_html += f"โ€ข {html.escape(str(var))}: <b>{'+' if change > 0 else ''}{change}</b><br/>"
        news_html += "</p>"
    news_html += "</div>"
    return news_html
# def create_price_chart():
# if len(history) <= 1:
# fig, ax = plt.subplots(figsize=(10, 6))
# ax.text(0.5, 0.5, 'Trade to see price history',
# ha='center', va='center', fontsize=14, color='gray')
# ax.axis('off')
# return fig
# df = pd.DataFrame(history)
# fig, ax = plt.subplots(figsize=(10, 6))
# colors = ['#3b82f6', '#10b981', '#f59e0b', '#ef4444', '#8b5cf6']
# for i, stock in enumerate(env.stocks):
# ax.plot(df['day'], df[stock], marker='o', linewidth=2, color=colors[i % len(colors)], label=stock)
# ax.set_xlabel('Day')
# ax.set_ylabel('Price ($)')
# ax.set_title('Stock Price History')
# ax.legend()
# ax.grid(True, alpha=0.3)
# return fig
# def create_price_chart():
# """Create individual price chart for each stock"""
# if len(history) <= 1:
# fig, axs = plt.subplots(1, 1, figsize=(10, 6))
# axs.text(0.5, 0.5, 'Trade to see price history',
# ha='center', va='center', fontsize=14, color='gray')
# axs.axis('off')
# return fig
# df = pd.DataFrame(history)
# num_stocks = len(env.stocks)
# fig, axs = plt.subplots(num_stocks, 1, figsize=(10, 4*num_stocks), sharex=True)
# # With a single stock, plt.subplots returns one Axes rather than an array, so wrap it in a list
# if num_stocks == 1:
# axs = [axs]
# colors = ['#3b82f6', '#10b981', '#f59e0b', '#ef4444', '#8b5cf6']
# for i, stock in enumerate(env.stocks):
# ax = axs[i]
# ax.plot(df['day'], df[stock], marker='o', linewidth=2, color=colors[i % len(colors)], label=stock)
# ax.set_ylabel(f'{stock} ($)')
# ax.set_title(f'{stock} Price History')
# ax.legend(loc='best', framealpha=0.8)
# ax.grid(True, alpha=0.3)
# axs[-1].set_xlabel('Day')
# plt.tight_layout()
# return fig
def create_price_chart():
    """Build a Plotly line chart of every stock's price across the recorded days.

    Reads the module-level ``history`` snapshots and ``env.stocks``. Until at
    least two snapshots exist, it returns an empty figure with a hint.
    """
    palette = ['#3b82f6', '#10b981', '#f59e0b', '#ef4444', '#8b5cf6']
    fig = go.Figure()

    if len(history) <= 1:
        # No trading history yet: show a blank chart with a hint.
        fig.add_annotation(
            text="Trade to see price history",
            xref="paper",
            yref="paper",
            showarrow=False,
            font=dict(size=16, color="gray"),
        )
        fig.update_layout(
            xaxis=dict(visible=False),
            yaxis=dict(visible=False),
            template="plotly_white",
            height=400,
        )
        return fig

    frame = pd.DataFrame(history)
    for position, ticker in enumerate(env.stocks):
        shade = palette[position % len(palette)]
        trace = go.Scatter(
            x=frame['day'],
            y=frame[ticker],
            mode='lines+markers',
            name=ticker,
            line=dict(color=shade, width=2),
            marker=dict(size=6),
        )
        fig.add_trace(trace)

    # Grow the chart with the number of series so the legend does not crowd it.
    fig.update_layout(
        title="Stock Price History",
        xaxis_title="Day",
        yaxis_title="Price ($)",
        template="plotly_white",
        legend=dict(title="Stocks", orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
        height=400 + 50 * len(env.stocks),
    )
    return fig
def create_value_chart():
    """Build a Matplotlib chart of total portfolio value per recorded day.

    Reads the module-level ``history`` snapshots. Until at least two exist, it
    returns an axis-less figure with a hint.

    The figure is built from ``matplotlib.figure.Figure`` rather than
    ``pyplot``. pyplot keeps every figure it creates until it is explicitly
    closed, and this handler runs on every UI refresh, so pyplot figures
    would accumulate for the life of the process.
    """
    from matplotlib.figure import Figure

    fig = Figure(figsize=(10, 6))
    ax = fig.subplots()
    if len(history) <= 1:
        ax.text(0.5, 0.5, 'Trade to see portfolio value', ha='center', va='center', fontsize=14, color='gray')
        ax.axis('off')
        return fig
    df = pd.DataFrame(history)
    ax.plot(df['day'], df['total_value'], marker='o', linewidth=3, color='#8b5cf6', label='Portfolio Value')
    ax.fill_between(df['day'], df['total_value'], alpha=0.2, color='#8b5cf6')
    # Dashed reference line at the day-0 value makes gains and losses obvious.
    initial_value = history[0]['total_value']
    ax.axhline(y=initial_value, color='red', linestyle='--', alpha=0.5, label=f'Initial: ${initial_value:.2f}')
    ax.set_xlabel('Day')
    ax.set_ylabel('Value ($)')
    ax.set_title('Portfolio Value History')
    ax.legend()
    ax.grid(True, alpha=0.3)
    return fig
def execute_trade(stock, action, amount):
    """Buy or sell shares at today's price without advancing the day.

    Args:
        stock: Ticker name; it must be one of ``env.stocks``.
        action: ``"Buy"``; any other value is treated as a sell.
        amount: Share count from the number input. It may be None or a
            float and is truncated with ``int``.

    Returns:
        ``(status, portfolio_df, news_html, price_chart, value_chart)``. When a
        trade is rejected, only ``status`` is set.
    """
    global env, history
    if env is None:
        return "โŒ Please initialize the environment first", None, None, None, None
    if amount is None:
        return "โŒ Please enter an amount", None, None, None, None
    if stock not in env.stocks:
        # The dropdown's tickers may not match the loaded config's stocks.
        available = ", ".join(map(str, env.stocks))
        return f"โŒ Unknown stock '{stock}'. Available: {available}", None, None, None, None
    try:
        amount = int(amount)
        if amount <= 0:
            return "โŒ Amount must be positive", None, None, None, None
        idx = env.stocks.index(stock)
        price = env.prices[idx]
        if action == "Buy":
            cost = price * amount
            if cost > env.cash:
                return f"โŒ Insufficient cash!\nNeed: ${cost:.2f}\nHave: ${env.cash:.2f}", None, None, None, None
            env.positions[idx] += amount
            env.cash -= cost
            status = f"โœ… Bought {amount} {stock} at ${price:.2f}"
        else:
            # Sell at most what is held. Positions are stored as float64, so
            # cast to int to report "Sold 5" rather than "Sold 5.0".
            qty = int(min(amount, env.positions[idx]))
            if qty == 0:
                return "โŒ No shares to sell", None, None, None, None
            revenue = price * qty
            env.positions[idx] -= qty
            env.cash += revenue
            status = f"โœ… Sold {qty} {stock} at ${price:.2f}"
        obs = env._get_observation()
        status += f"\n๐Ÿ’ฐ Cash: ${obs['cash']:.2f} | Total Value: ${obs['total_value']:.2f}"
        return status, create_portfolio_display(obs), create_news_display(obs), create_price_chart(), create_value_chart()
    except Exception as e:
        # UI boundary: show unexpected errors (e.g. a NaN amount) in the status box.
        return f"โŒ Error: {str(e)}", None, None, None, None
def _final_summary(final_value, initial_value):
    """Format the end-of-simulation status with absolute and percentage profit."""
    profit = final_value - initial_value
    # A config with initial_cash 0 would otherwise raise ZeroDivisionError.
    pct = profit / initial_value * 100 if initial_value else 0.0
    return f"๐Ÿ Finished! Final value: ${final_value:.2f}\nProfit: {profit:+.2f} ({pct:+.2f}%)"


def advance_day():
    """Advance the simulation one day with no trades and refresh the dashboard.

    Each step appends the new day's snapshot to ``history``. Once the final
    day has been reached, further presses do not step the environment; they
    just repeat the final summary.

    Returns:
        ``(status, portfolio_df, news_html, price_chart, value_chart)``.
    """
    global env, history
    if env is None:
        return "โŒ Please initialize the environment first", None, None, None, None
    if env.t >= env.num_days:
        # Stepping a finished env would keep incrementing the day counter and
        # append duplicate (unchanged-price) snapshots to the chart history.
        obs = env._get_observation()
        status = _final_summary(obs['total_value'], history[0]['total_value'])
        status += "\nPress Reset to start over."
        return status, create_portfolio_display(obs), create_news_display(obs), create_price_chart(), create_value_chart()
    obs, _reward, done, _info = env.step({"buy": {}, "sell": {}})
    history.append({'day': obs['day'], 'total_value': obs['total_value'], **obs['prices']})
    if done:
        status = _final_summary(obs['total_value'], history[0]['total_value'])
    else:
        status = f"โœ… Advanced to Day {obs['day']} | ๐Ÿ’ฐ Cash: ${obs['cash']:.2f} | ๐Ÿ“Š Value: ${obs['total_value']:.2f}"
    return status, create_portfolio_display(obs), create_news_display(obs), create_price_chart(), create_value_chart()
def reset_env():
    """Restart the current simulation from day 0, keeping its config.

    If nothing has been loaded yet, it falls back to ``initialize_env()``,
    which uses the default config.

    Returns:
        The same 5-tuple as the other handlers.
    """
    global env, history
    if env is None:
        return initialize_env()
    obs = env.reset()
    history = [{'day': obs['day'], 'total_value': obs['total_value'], **obs['prices']}]
    status = f"๐Ÿ”„ Environment Reset! Day {obs['day']}, Cash ${obs['cash']:.2f}"
    # Return the "trade to see ..." placeholder charts, as initialize_env does,
    # instead of None, which would clear the plot components.
    return status, create_portfolio_display(obs), create_news_display(obs), create_price_chart(), create_value_chart()
# ======== UI section ========
custom_css = """
.gradio-container { font-family: 'Arial', sans-serif; }
"""
# NOTE(review): `env` and `history` are module-level globals, so all browser
# sessions share one simulation and every page load (`demo.load`) re-initialises it.


def _sync_stock_choices():
    """Return a dropdown update listing the loaded environment's tickers.

    execute_trade looks tickers up in ``env.stocks``, so the dropdown must
    follow whichever config was loaded last.
    """
    if env is None:
        return gr.update()
    stocks = list(env.stocks)
    return gr.update(choices=stocks, value=stocks[0] if stocks else None)


# List the config folder once instead of three times while building the dropdown.
_config_files = list_config_files()

with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, title="AI Trading Arena") as demo:
    gr.Markdown("# ๐Ÿš€ AI Trading Arena\n### Interactive Stock Trading Simulator")
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## ๐ŸŽฎ Control Panel")
            with gr.Accordion("๐Ÿ“ Configuration", open=False):
                gr.Markdown("Select a config file from `/config` folder or paste custom JSON")
                config_file_dropdown = gr.Dropdown(
                    choices=_config_files,
                    label="Choose Config File",
                    value=_config_files[0] if _config_files else None
                )
                load_file_btn = gr.Button("๐Ÿ“‚ Load from File", variant="secondary")
                config_input = gr.Textbox(
                    label="Custom Config JSON",
                    placeholder='{"num_days": 30, "stocks": ["TECH", "ENERGY"], ...}',
                    lines=4
                )
                init_btn = gr.Button("๐Ÿš€ Load Config", variant="primary", size="lg")
            with gr.Accordion("๐Ÿ’น Trading Operations", open=True):
                # Seeded with the default config's tickers (what demo.load
                # starts with); _sync_stock_choices replaces them whenever a
                # config is loaded.
                stock_dropdown = gr.Dropdown(
                    choices=list(DEFAULT_CONFIG["stocks"]),
                    label="Select Stock",
                    value=DEFAULT_CONFIG["stocks"][0]
                )
                action_radio = gr.Radio(choices=["Buy", "Sell"], label="Action", value="Buy")
                amount_input = gr.Number(label="Amount (shares)", value=10, minimum=1, step=1)
                trade_btn = gr.Button("๐Ÿ“ˆ Execute Trade", variant="primary", size="lg")
            with gr.Row():
                advance_btn = gr.Button("โญ๏ธ Next Day", variant="primary", size="lg")
                reset_btn = gr.Button("๐Ÿ”„ Reset", variant="secondary", size="lg")
            status_output = gr.Textbox(label="๐Ÿ“Š Status & Messages", lines=8, interactive=False)
        with gr.Column(scale=2):
            gr.Markdown("## ๐Ÿ“Š Market Dashboard")
            portfolio_table = gr.Dataframe(label="๐Ÿ’ผ Portfolio Holdings", interactive=False)
            news_display = gr.HTML(label="๐Ÿ“ฐ Market News")
            with gr.Tab("๐Ÿ“ˆ Price History"):
                price_chart = gr.Plot(label="Stock Prices Over Time")
            with gr.Tab("๐Ÿ’ฐ Portfolio Value"):
                value_chart = gr.Plot(label="Total Portfolio Value")

    # Wire up the event handlers. Every handler returns the same five dashboard outputs.
    dashboard_outputs = [status_output, portfolio_table, news_display, price_chart, value_chart]
    load_file_btn.click(fn=load_config_from_file, inputs=[config_file_dropdown], outputs=[config_input])
    # After (re)initialising, point the stock dropdown at the new config's tickers.
    init_btn.click(fn=initialize_env, inputs=[config_input], outputs=dashboard_outputs).then(
        fn=_sync_stock_choices, inputs=[], outputs=[stock_dropdown]
    )
    reset_btn.click(fn=reset_env, inputs=[], outputs=dashboard_outputs)
    trade_btn.click(fn=execute_trade, inputs=[stock_dropdown, action_radio, amount_input],
                    outputs=dashboard_outputs)
    advance_btn.click(fn=advance_day, inputs=[], outputs=dashboard_outputs)
    demo.load(fn=initialize_env, inputs=[], outputs=dashboard_outputs).then(
        fn=_sync_stock_choices, inputs=[], outputs=[stock_dropdown]
    )
if __name__ == "__main__":
    # Running the file directly serves the app; importing it only builds `demo`.
    demo.launch()