Fangzhi Xu committed on
Commit
11353e3
·
1 Parent(s): 5b6956f
Files changed (1) hide show
  1. app.py +151 -289
app.py CHANGED
@@ -1,155 +1,83 @@
1
  import os
2
- import gradio as gr
3
- import numpy as np
4
  import json
5
- import pandas as pd
 
6
  import matplotlib.pyplot as plt
7
- import matplotlib
8
- matplotlib.use('Agg')
9
-
10
- # =========================================
11
- # ======= Environment Core Definition ======
12
- # =========================================
13
-
14
class TradeArenaEnv_Deterministic:
    """Odyssey Arena - AI Trading Environment (Deterministic version).

    Prices move only in response to the per-day ``variable_changes`` listed in
    the config timeline, plus Gaussian noise when ``price_noise_scale`` is
    non-zero.

    Config keys read by the constructor:
        Required: ``num_days``, ``stocks``, ``variables``,
        ``dependency_matrix``, ``initial_prices``, ``initial_variables``,
        ``timeline``.
        Optional: ``price_noise_scale`` (default 0.0) and ``initial_cash``
        (default 10000.0).
    """

    def __init__(self, cfg):
        """Store the scenario described by ``cfg`` and reset to day 0.

        Raises:
            KeyError: if a required config key is missing.
        """
        self.num_days = cfg["num_days"]
        self.stocks = cfg["stocks"]
        self.variables = cfg["variables"]
        # Multiplied as ``dependency_matrix @ variable_deltas`` to get price
        # deltas; presumably shaped (n_stocks, n_variables) - TODO confirm.
        self.dependency_matrix = np.array(cfg["dependency_matrix"])
        self.initial_prices = np.array(cfg["initial_prices"])
        self.initial_variables = np.array(cfg["initial_variables"])
        self.timeline = cfg["timeline"]
        self.price_noise_scale = cfg.get("price_noise_scale", 0.0)
        self.initial_cash = cfg.get("initial_cash", 10000.0)
        self.reset()

    def reset(self):
        """Restore day 0 state and return the initial observation."""
        self.t = 0
        self.cash = self.initial_cash
        self.positions = np.zeros(len(self.stocks), dtype=np.float64)
        self.prices = self.initial_prices.copy().astype(np.float64)
        self.variables_state = self.initial_variables.copy().astype(np.float64)
        self.next_day_news = self.timeline.get("day_1", None)
        return self._get_observation()

    def _get_observation(self):
        """Return a plain-Python snapshot of the current state."""
        news = self.next_day_news
        return {
            "day": self.t,
            "prices": {s: float(p) for s, p in zip(self.stocks, self.prices)},
            "cash": float(self.cash),
            "positions": {s: int(pos) for s, pos in zip(self.stocks, self.positions)},
            "total_value": float(self.cash + np.sum(self.positions * self.prices)),
            "news_next_day": news["variable_changes"] if news else None,
            "news_next_day_text": news["news_text"] if news else None,
        }

    def step(self, action):
        """Execute the day's orders, then advance the simulation by one day.

        Args:
            action: dict with optional ``"sell"`` and ``"buy"`` entries, each
                mapping a stock name to a share count. Sells run before buys.
                Sell quantities are capped at the current holding; a buy that
                costs more than the available cash is skipped. Non-positive
                quantities are ignored.

        Returns:
            ``(obs, reward, done, info)`` where ``reward`` is the rounded
            total portfolio value and ``info`` is an empty dict.

        Raises:
            TypeError: if ``action`` is not a dict.
            ValueError: if an order names a stock not in ``self.stocks``.
        """
        # Explicit check instead of ``assert``: asserts vanish under ``-O``.
        if not isinstance(action, dict):
            raise TypeError("action must be a dict with optional 'sell'/'buy' entries")

        # Execute sells first so their proceeds can fund the buys.
        for stock, qty in (action.get("sell") or {}).items():
            idx = self.stocks.index(stock)
            qty = min(int(qty), self.positions[idx])
            # A negative "sell" would otherwise act as a buy with no cash check.
            if qty <= 0:
                continue
            self.positions[idx] -= qty
            self.cash += self.prices[idx] * qty

        # Then buys
        for stock, qty in (action.get("buy") or {}).items():
            idx = self.stocks.index(stock)
            qty = int(qty)
            # A negative "buy" would otherwise create a short position and
            # credit cash without owning any shares.
            if qty <= 0:
                continue
            cost = self.prices[idx] * qty
            if cost <= self.cash:
                self.positions[idx] += qty
                self.cash -= cost

        # Advance one day
        self.t += 1
        done = self.t >= self.num_days

        # Update variable states & prices
        if not done:
            news_today = self.timeline.get(f"day_{self.t}", None)
            if news_today:
                deltas = np.array(news_today["variable_changes"])
                self.variables_state += deltas
                self._update_prices_from_variables(deltas)

        # Prepare next day's news
        self.next_day_news = self.timeline.get(f"day_{self.t + 1}", None) if not done else None

        reward = self._compute_reward()
        obs = self._get_observation()
        return obs, reward, done, {}

    def _update_prices_from_variables(self, delta_vars):
        """Apply the price change implied by ``delta_vars`` (plus noise)."""
        delta_price = self.dependency_matrix @ delta_vars
        if self.price_noise_scale == 0:
            noise = np.zeros_like(delta_price)
        else:
            noise = np.random.normal(0, self.price_noise_scale, len(self.stocks))
        self.prices += delta_price + noise
        # Keep every price at or above 0.1.
        self.prices = np.clip(self.prices, 0.1, None)

    def _compute_reward(self):
        """Return cash plus the market value of all holdings, rounded to cents."""
        total_value = self.cash + np.sum(self.positions * self.prices)
        return round(float(total_value), 2)
102
-
103
-
104
- # =========================================
105
- # =========== Default Config ==============
106
- # =========================================
107
-
108
# Scripted news for the built-in scenario, one (variable_changes, headline)
# pair per day starting at day 1.
_DEFAULT_TIMELINE_EVENTS = [
    ([0.1, -0.2, 0.3],
     "Federal Reserve hints at rate increase; Oil prices drop on oversupply concerns"),
    ([-0.1, 0.3, 0.2],
     "Tech sector shows strong earnings; Energy stocks rally on production cuts"),
    ([0.2, 0.1, -0.1],
     "Market sentiment cautious amid geopolitical tensions"),
    ([0.0, 0.2, 0.1],
     "Stable interest rates; Energy sector momentum continues"),
    ([-0.2, -0.1, 0.0],
     "Rate cut speculation; Market consolidation"),
]

# Scenario used when no custom configuration is supplied.
DEFAULT_CONFIG = dict(
    num_days=30,
    stocks=["TECH", "ENERGY", "FINANCE"],
    variables=["interest_rate", "oil_price", "market_sentiment"],
    # One row per stock, one column per variable.
    dependency_matrix=[
        [-5, 2, 3],
        [1, 8, 2],
        [-3, 1, 4],
    ],
    initial_prices=[100, 80, 120],
    initial_variables=[0, 0, 0],
    initial_cash=10000,
    price_noise_scale=0,
    timeline={
        f"day_{day}": {"variable_changes": changes, "news_text": headline}
        for day, (changes, headline) in enumerate(_DEFAULT_TIMELINE_EVENTS, start=1)
    },
)
134
-
135
# =========================================
# =========== Global State ================
# =========================================

# Active TradeArenaEnv_Deterministic instance; None until initialize_env()
# builds one.
env = None
# Per-day snapshots (day, total_value and one price column per stock);
# initialize_env() seeds it with the day-0 observation.
history = []
141
 
142
 
143
# =========================================
# =========== Utility Functions ===========
# =========================================

def list_config_files():
    """Return the ``.json`` file names found in the local ``config`` directory.

    Returns:
        A sorted list of file names (not full paths). The list is empty when
        ``config`` does not exist or is not a directory.
    """
    config_dir = "config"
    # isdir (not exists): a plain file named "config" would make listdir raise.
    if not os.path.isdir(config_dir):
        return []
    # os.listdir order is arbitrary; sort so the dropdown default is stable.
    return sorted(f for f in os.listdir(config_dir) if f.endswith(".json"))
152
 
 
153
  def load_config_from_file(filename):
154
  try:
155
  path = os.path.join("config", filename)
@@ -160,180 +88,114 @@ def load_config_from_file(filename):
160
  return f"❌ Error reading {filename}: {str(e)}"
161
 
162
 
163
- # =========================================
164
- # ============ Core Logic =================
165
- # =========================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
 
167
def initialize_env(config_file=None):
    """Create a fresh trading environment and return the initial UI state.

    Args:
        config_file: JSON text describing the scenario. ``None`` or a blank
            string selects ``DEFAULT_CONFIG``.

    Returns:
        A 6-tuple matching the Gradio outputs: status text, portfolio
        DataFrame, news HTML, price figure, value figure, and an update for
        the stock dropdown. On invalid input the first element is an error
        message, the next four are ``None`` and the dropdown is left as is.
    """
    global env, history

    if config_file is not None and config_file.strip():
        try:
            config = json.loads(config_file)
        except (TypeError, ValueError):
            # json.JSONDecodeError is a ValueError subclass.
            return "❌ Invalid JSON file", None, None, None, None, gr.update()
    else:
        config = DEFAULT_CONFIG

    # Well-formed JSON can still be an unusable config (missing keys, wrong
    # shapes, no stocks); report that instead of crashing the handler.
    try:
        new_env = TradeArenaEnv_Deterministic(config)
        if not new_env.stocks:
            raise ValueError("'stocks' must not be empty")
    except (KeyError, TypeError, ValueError, AttributeError) as exc:
        return f"❌ Invalid config: {exc}", None, None, None, None, gr.update()

    # Only replace the global environment once the new one is known to be good.
    env = new_env
    obs = env.reset()

    history = [{
        'day': obs['day'],
        'total_value': obs['total_value'],
        **obs['prices']
    }]

    status = f"✅ Session initialized!\n📅 Day: {obs['day']}\n💰 Cash: ${obs['cash']:.2f}\n📊 Total Value: ${obs['total_value']:.2f}"

    return (
        status,
        create_portfolio_display(obs),
        create_news_display(obs),
        create_price_chart(),
        create_value_chart(),
        gr.update(choices=env.stocks, value=env.stocks[0])
    )
197
 
198
def create_portfolio_display(obs):
    """Build the holdings table shown in the dashboard.

    Args:
        obs: observation dict; its ``'prices'`` and ``'positions'`` entries
            are looked up by stock name.

    Returns:
        A pandas DataFrame with one row per stock in ``env.stocks`` and the
        columns ``Stock``, ``Price``, ``Holdings`` and ``Value``.
    """
    prices = obs['prices']
    holdings = obs['positions']
    rows = [
        {
            'Stock': ticker,
            'Price': f"${prices[ticker]:.2f}",
            'Holdings': holdings[ticker],
            'Value': f"${prices[ticker] * holdings[ticker]:.2f}",
        }
        for ticker in env.stocks
    ]
    return pd.DataFrame(rows)
208
-
209
def create_news_display(obs):
    """Render the next day's news as an HTML card.

    Args:
        obs: observation dict with ``'news_next_day_text'`` (headline or
            ``None``) and ``'news_next_day'`` (per-variable changes or
            ``None``).

    Returns:
        An HTML string: a styled news card, or a "no more news" placeholder
        when there is no headline.
    """
    # Local import keeps this block self-contained.
    import html

    headline = obs['news_next_day_text']
    if not headline:
        return "<div style='padding: 20px; background: #f0f0f0; border-radius: 10px; text-align: center;'>📭 No more news available</div>"

    # Headlines and variable names come from a user-editable JSON config, so
    # escape them before embedding in markup.
    news_html = f"""
    <div style='background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    padding: 20px; border-radius: 10px; color: white; margin: 10px 0;'>
    <h3 style='margin-top: 0;'>📰 Next Day News</h3>
    <p style='font-size: 16px; line-height: 1.6;'>{html.escape(str(headline))}</p>
    """
    changes = obs['news_next_day']
    if changes:
        news_html += "<p style='font-size: 14px; margin-top: 10px;'><b>Variable Changes:</b><br/>"
        # zip stops at the shorter sequence, so a length mismatch between the
        # variables and the changes cannot raise IndexError.
        for var, change in zip(env.variables, changes):
            sign = '+' if change > 0 else ''
            news_html += f"• {html.escape(str(var))}: <b>{sign}{change}</b><br/>"
        news_html += "</p>"
    news_html += "</div>"
    return news_html
227
 
228
def create_price_chart():
    """Plot each stock's price history from the global ``history`` list.

    Returns:
        A matplotlib Figure. With at most one recorded snapshot it shows a
        placeholder message instead of a chart.
    """
    # Build the figure directly rather than through pyplot: plt.subplots()
    # registers every figure in pyplot's global state, and since this runs on
    # each UI refresh without a matching plt.close(), figures would pile up.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(10, 6))
    ax = fig.subplots()

    if len(history) <= 1:
        ax.text(0.5, 0.5, 'Trade to see price history', ha='center', va='center', fontsize=14, color='gray')
        ax.axis('off')
        return fig

    df = pd.DataFrame(history)
    colors = ['#3b82f6', '#10b981', '#f59e0b', '#ef4444', '#8b5cf6']

    for i, stock in enumerate(env.stocks):
        # Cycle through the palette when there are more stocks than colours.
        ax.plot(df['day'], df[stock], marker='o', linewidth=2,
                color=colors[i % len(colors)], label=stock)

    ax.set_xlabel('Day', fontsize=12, fontweight='bold')
    ax.set_ylabel('Price ($)', fontsize=12, fontweight='bold')
    ax.set_title('Stock Price History', fontsize=14, fontweight='bold', pad=20)
    ax.legend(loc='best', framealpha=0.9)
    ax.grid(True, alpha=0.3)
    ax.set_facecolor('#f8f9fa')
    fig.patch.set_facecolor('white')
    # Lay out this figure explicitly instead of pyplot's "current" figure.
    fig.tight_layout()
    return fig
252
 
253
def create_value_chart():
    """Plot total portfolio value over time from the global ``history`` list.

    Returns:
        A matplotlib Figure. With at most one recorded snapshot it shows a
        placeholder message instead of a chart.
    """
    # Build the figure directly rather than through pyplot: plt.subplots()
    # registers every figure in pyplot's global state, and since this runs on
    # each UI refresh without a matching plt.close(), figures would pile up.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(10, 6))
    ax = fig.subplots()

    if len(history) <= 1:
        ax.text(0.5, 0.5, 'Trade to see portfolio value', ha='center', va='center', fontsize=14, color='gray')
        ax.axis('off')
        return fig

    df = pd.DataFrame(history)
    ax.plot(df['day'], df['total_value'], marker='o', linewidth=3, color='#8b5cf6', label='Portfolio Value')
    ax.fill_between(df['day'], df['total_value'], alpha=0.2, color='#8b5cf6')
    # Dashed reference line at the first recorded value.
    initial_value = history[0]['total_value']
    ax.axhline(y=initial_value, color='red', linestyle='--', alpha=0.5, label=f'Initial: ${initial_value:.2f}')
    ax.legend(loc='best', framealpha=0.9)
    ax.set_xlabel('Day', fontsize=12, fontweight='bold')
    ax.set_ylabel('Total Value ($)', fontsize=12, fontweight='bold')
    ax.set_title('Portfolio Value Over Time', fontsize=14, fontweight='bold', pad=20)
    ax.grid(True, alpha=0.3)
    # Lay out this figure explicitly instead of pyplot's "current" figure.
    fig.tight_layout()
    return fig
273
 
274
 
275
- # =========================================
276
- # ============ UI Definition ==============
277
- # =========================================
 
 
 
278
 
279
# Stylesheet passed to gr.Blocks(css=...): sets the container font and gives
# the .gr-button-primary / .gr-button-secondary classes gradient backgrounds.
custom_css = """
.gradio-container { font-family: 'Arial', sans-serif; }
.gr-button-primary {
background: linear-gradient(90deg, #667eea 0%, #764ba2 100%) !important;
border: none !important;
}
.gr-button-secondary {
background: linear-gradient(90deg, #f093fb 0%, #f5576c 100%) !important;
border: none !important;
}
"""
290
 
291
# Layout: a control panel (config, trading, status) on the left and the market
# dashboard (holdings, news, charts) on the right.
with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, title="AI Trading Arena") as demo:
    gr.Markdown("# 🚀 AI Trading Arena\n### Interactive Stock Trading Simulator")

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## 🎮 Control Panel")

            with gr.Accordion("📁 Configuration", open=False):
                # Scan the config directory once so the dropdown's choices and
                # its default value always come from the same listing.
                config_files = list_config_files()
                config_file_dropdown = gr.Dropdown(
                    choices=config_files,
                    label="Choose Config File from /config",
                    value=config_files[0] if config_files else None
                )
                load_file_btn = gr.Button("📂 Load from File", variant="secondary")
                config_input = gr.Textbox(
                    label="Custom Config JSON",
                    placeholder='{"num_days": 30, "stocks": ["TECH", "ENERGY"], ...}',
                    lines=4
                )
                init_btn = gr.Button("🚀 Initialize Environment", variant="primary", size="lg")

            with gr.Accordion("💹 Trading Operations", open=True):
                stock_dropdown = gr.Dropdown(choices=DEFAULT_CONFIG["stocks"], label="Select Stock", value=DEFAULT_CONFIG["stocks"][0])
                action_radio = gr.Radio(choices=["Buy", "Sell"], label="Action", value="Buy")
                amount_input = gr.Number(label="Amount (shares)", value=10, minimum=1, step=1)
                # NOTE(review): trade_btn, advance_btn and reset_btn have no
                # click handlers bound below.
                trade_btn = gr.Button("📈 Execute Trade", variant="primary", size="lg")

            with gr.Row():
                advance_btn = gr.Button("⏭️ Next Day", variant="primary", size="lg")
                reset_btn = gr.Button("🔄 Reset", variant="secondary", size="lg")

            status_output = gr.Textbox(label="📊 Status & Messages", lines=8, interactive=False)

        with gr.Column(scale=2):
            gr.Markdown("## 📊 Market Dashboard")
            portfolio_table = gr.Dataframe(label="💼 Portfolio Holdings", interactive=False)
            news_display = gr.HTML(label="📰 Market News")
            with gr.Tab("📈 Price History"):
                price_chart = gr.Plot(label="Stock Prices Over Time")
            with gr.Tab("💰 Portfolio Value"):
                value_chart = gr.Plot(label="Total Portfolio Value")

    # === Button Bindings ===
    load_file_btn.click(fn=load_config_from_file, inputs=[config_file_dropdown], outputs=[config_input])
    init_btn.click(fn=initialize_env, inputs=[config_input],
                   outputs=[status_output, portfolio_table, news_display, price_chart, value_chart, stock_dropdown])

    # Start a default session when the page loads.
    demo.load(fn=initialize_env, inputs=[], outputs=[status_output, portfolio_table, news_display, price_chart, value_chart, stock_dropdown])

if __name__ == "__main__":
    demo.launch()
 
1
  import os
 
 
2
  import json
3
+ import random
4
+ import gradio as gr
5
  import matplotlib.pyplot as plt
6
+
7
+ # ======================
8
+ # Environment Definition
9
+ # ======================
10
class MarketEnv:
    """Random-walk stock market with a single cash account.

    Config keys (all optional):
        ``stocks``: list of ticker names (default AAPL, GOOG, TSLA, AMZN).
        ``num_days``: number of simulated days (default 30).
        ``initial_cash``: starting cash (default 10000).

    Attributes:
        current_day: number of ``step`` calls since the last reset.
        cash: available cash.
        portfolio: dict mapping ticker -> shares held.
        prices: dict mapping ticker -> list of prices; the last entry is the
            price used for trades.
    """

    def __init__(self, config):
        self.config = config
        self.stocks = config.get("stocks", ["AAPL", "GOOG", "TSLA", "AMZN"])
        self.num_days = config.get("num_days", 30)
        # Configurable so scenario files can set the starting balance; the
        # default keeps the previously hard-coded 10000.
        self.initial_cash = config.get("initial_cash", 10000)
        self.reset()

    def reset(self):
        """Start a new episode and return the first observation."""
        self.current_day = 0
        self.cash = self.initial_cash
        self.portfolio = {s: 0 for s in self.stocks}
        # Each stock opens at a random price in [50, 150].
        self.prices = {s: [random.uniform(50, 150)] for s in self.stocks}
        self.generate_next_day()
        return self._get_obs()

    def generate_next_day(self):
        """Append one new price per stock; does nothing once the run is over."""
        if self.current_day >= self.num_days:
            return
        for s in self.stocks:
            last_price = self.prices[s][-1]
            # Daily move of at most +/-5%, with the price floored at 1.
            change = random.uniform(-0.05, 0.05)
            self.prices[s].append(max(1, last_price * (1 + change)))

    def step(self, action):
        """Advance one day.

        ``action`` is accepted for interface compatibility but not used;
        trades are placed through ``buy`` and ``sell``.

        Returns:
            ``(obs, reward, done, info)`` with ``info`` an empty dict.
        """
        self.current_day += 1
        self.generate_next_day()
        obs = self._get_obs()
        reward = self._compute_reward()
        done = self.current_day >= self.num_days
        return obs, reward, done, {}

    def buy(self, stock, amount):
        """Buy ``amount`` shares of ``stock`` at the latest price.

        Returns:
            True if the order was filled. False (state unchanged) if the
            stock is unknown, the amount is not positive, or cash is short.
        """
        # A non-positive amount would otherwise credit cash and leave a
        # negative holding.
        if stock not in self.portfolio or amount <= 0:
            return False
        cost = self.prices[stock][-1] * amount
        if cost > self.cash:
            return False
        self.cash -= cost
        self.portfolio[stock] += amount
        return True

    def sell(self, stock, amount):
        """Sell ``amount`` shares of ``stock`` at the latest price.

        Returns:
            True if the order was filled. False (state unchanged) if the
            stock is unknown, the amount is not positive, or fewer than
            ``amount`` shares are held.
        """
        # A non-positive amount would otherwise act as a buy with no cash check.
        if stock not in self.portfolio or amount <= 0:
            return False
        if self.portfolio[stock] < amount:
            return False
        self.cash += self.prices[stock][-1] * amount
        self.portfolio[stock] -= amount
        return True

    def _compute_reward(self):
        """Return cash plus the market value of all holdings."""
        return self.cash + sum(self.prices[s][-1] * self.portfolio[s] for s in self.stocks)

    def _get_obs(self):
        """Return a snapshot of day, cash, holdings and latest prices."""
        prices_today = {s: self.prices[s][-1] for s in self.stocks}
        return {
            "day": self.current_day,
            "cash": self.cash,
            # Copy so later trades do not rewrite an observation already returned.
            "portfolio": dict(self.portfolio),
            "prices": prices_today,
        }
62
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
# ==============
# Global Objects
# ==============
# Active MarketEnv instance; stays None until initialize_env() creates one.
env = None
# Rewards returned by env.step(), appended after each buy/sell and cleared by
# initialize_env().
history = []
69
 
70
 
71
# ==============
# Config Helpers
# ==============

def list_config_files():
    """Return the ``.json`` file names found in the local ``config`` directory.

    Returns:
        A sorted list of file names (not full paths). The list is empty when
        ``config`` does not exist or is not a directory.
    """
    config_dir = "config"
    # isdir (not exists): a plain file named "config" would make listdir raise.
    if not os.path.isdir(config_dir):
        return []
    # os.listdir order is arbitrary; sort so the dropdown default is stable.
    return sorted(f for f in os.listdir(config_dir) if f.endswith(".json"))
79
 
80
+
81
  def load_config_from_file(filename):
82
  try:
83
  path = os.path.join("config", filename)
 
88
  return f"❌ Error reading {filename}: {str(e)}"
89
 
90
 
91
+ # ==========================
92
+ # Visualization Helper Tools
93
+ # ==========================
94
def create_price_chart():
    """Plot the full price series of every stock in the global ``env``.

    Returns:
        A matplotlib Figure; the axes are empty (but titled) when no
        environment has been initialized yet.
    """
    # Build the figure directly rather than through pyplot: plt.subplots()
    # registers every figure in pyplot's global state, and since this runs on
    # each button click without a matching plt.close(), figures would pile up.
    from matplotlib.figure import Figure

    fig = Figure()
    ax = fig.subplots()
    # Callers may reach this before initialize_env() has run.
    if env is not None:
        for s in env.stocks:
            ax.plot(env.prices[s], label=s)
        if env.stocks:
            ax.legend()
    ax.set_title("Stock Prices")
    ax.set_xlabel("Days")
    ax.set_ylabel("Price")
    return fig
103
+
104
+
105
def create_portfolio_table():
    """Render current holdings of the global ``env`` as a Markdown table.

    Returns:
        A Markdown string with one row per stock (holdings, latest price and
        position value). Only the header rows are returned when no
        environment has been initialized yet.
    """
    header = "| Stock | Holdings | Price | Value |\n|:--|:--|:--|:--|\n"
    # Callers may reach this before initialize_env() has run.
    if env is None:
        return header
    rows = []
    for s in env.stocks:
        price = env.prices[s][-1]
        qty = env.portfolio[s]
        rows.append(f"| {s} | {qty} | {price:.2f} | {qty * price:.2f} |\n")
    return header + "".join(rows)
112
 
113
+
114
+ # =====================
115
+ # Gradio Event Handlers
116
+ # =====================
117
def initialize_env(config_json):
    """Create a fresh MarketEnv from JSON text and return the initial UI state.

    Args:
        config_json: JSON object text. Blank, unparsable or unusable input
            falls back to the built-in default config.

    Returns:
        A 4-tuple matching the Gradio outputs: status message, portfolio
        table (Markdown), price figure, and an update for the stock dropdown.
    """
    global env, history
    default_cfg = {"stocks": ["AAPL", "GOOG", "TSLA", "AMZN"], "num_days": 30}
    # Blank input is the normal "use defaults" path and gets no warning.
    has_input = isinstance(config_json, str) and bool(config_json.strip())

    try:
        cfg = json.loads(config_json)
    except (TypeError, ValueError):
        # json.JSONDecodeError is a ValueError subclass; TypeError covers None.
        cfg = None

    fell_back = False
    try:
        # Valid JSON that is not an object (e.g. a list) has no .get().
        if not isinstance(cfg, dict):
            raise ValueError("config must be a JSON object")
        new_env = MarketEnv(cfg)
        # The dropdown needs at least one stock to select.
        if not new_env.stocks:
            raise ValueError("'stocks' must not be empty")
    except (TypeError, ValueError, KeyError, AttributeError):
        new_env = MarketEnv(default_cfg)
        fell_back = has_input

    # MarketEnv.__init__ already resets, so no second reset() is needed.
    env = new_env
    history = []

    status = f"✅ Environment initialized with {len(env.stocks)} stocks."
    if fell_back:
        # Tell the user their config was ignored rather than silently
        # substituting the defaults.
        status += " ⚠️ Config was not usable; defaults were applied."
    return (
        status,
        create_portfolio_table(),
        create_price_chart(),
        gr.update(choices=env.stocks, value=env.stocks[0])
    )
133
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
 
135
def act_buy(stock, amount):
    """Handle the Buy button: place the order, then advance one day.

    Args:
        stock: ticker selected in the dropdown (may be ``None``).
        amount: number of shares from the slider.

    Returns:
        A 3-tuple for the Gradio outputs: status message, portfolio table and
        price chart.
    """
    if env is None:
        # The table/chart helpers read the global env, so leave those
        # components untouched instead of calling them here.
        return "❌ Please initialize environment first.", gr.update(), gr.update()

    try:
        qty = int(amount)
    except (TypeError, ValueError):
        qty = 0
    if stock not in env.portfolio or qty <= 0:
        return (
            "❌ Please choose a listed stock and a positive amount.",
            create_portfolio_table(),
            create_price_chart()
        )

    # Compare holdings before/after to learn whether the order was filled;
    # buy() silently does nothing when cash is insufficient.
    held_before = env.portfolio[stock]
    env.buy(stock, qty)
    filled = env.portfolio[stock] != held_before

    # Every submitted order advances the simulation by one day, filled or not.
    _, reward, _, _ = env.step(None)
    history.append(reward)

    if filled:
        message = f"✅ Bought {qty} of {stock}"
    else:
        message = f"❌ Not enough cash to buy {qty} of {stock}"
    return (
        message,
        create_portfolio_table(),
        create_price_chart()
    )
 
 
 
 
 
 
 
 
 
 
 
 
147
 
148
+
149
def act_sell(stock, amount):
    """Handle the Sell button: place the order, then advance one day.

    Args:
        stock: ticker selected in the dropdown (may be ``None``).
        amount: number of shares from the slider.

    Returns:
        A 3-tuple for the Gradio outputs: status message, portfolio table and
        price chart.
    """
    if env is None:
        # The table/chart helpers read the global env, so leave those
        # components untouched instead of calling them here.
        return "❌ Please initialize environment first.", gr.update(), gr.update()

    try:
        qty = int(amount)
    except (TypeError, ValueError):
        qty = 0
    if stock not in env.portfolio or qty <= 0:
        return (
            "❌ Please choose a listed stock and a positive amount.",
            create_portfolio_table(),
            create_price_chart()
        )

    # Compare holdings before/after to learn whether the order was filled;
    # sell() silently does nothing when too few shares are held.
    held_before = env.portfolio[stock]
    env.sell(stock, qty)
    filled = env.portfolio[stock] != held_before

    # Every submitted order advances the simulation by one day, filled or not.
    _, reward, _, _ = env.step(None)
    history.append(reward)

    if filled:
        message = f"✅ Sold {qty} of {stock}"
    else:
        message = f"❌ Not enough shares to sell {qty} of {stock}"
    return (
        message,
        create_portfolio_table(),
        create_price_chart()
    )
 
 
 
 
 
 
 
161
 
162
 
163
+ # ==============
164
+ # Gradio UI
165
+ # ==============
166
# Layout: status line, collapsible configuration panel, portfolio table and
# price chart, then the trade controls.
with gr.Blocks(theme=gr.themes.Soft(), title="AI Trading Arena") as demo:
    gr.Markdown("# 💹 AI Trading Arena")
    status_output = gr.Markdown("👋 Ready to start your trading journey!")

    with gr.Accordion("📁 Configuration", open=False):
        # Scan the config directory once so the dropdown's choices and its
        # default value always come from the same listing.
        config_files = list_config_files()
        config_file_dropdown = gr.Dropdown(
            choices=config_files,
            label="Choose Config File",
            value=config_files[0] if config_files else None
        )
        load_file_btn = gr.Button("📂 Load from File", variant="secondary")
        config_input = gr.Textbox(label="Custom Config JSON", lines=4)
        init_btn = gr.Button("🚀 Initialize Environment", variant="primary")

    portfolio_table = gr.Markdown()
    price_chart = gr.Plot()

    with gr.Row():
        # Same tickers as the default config, with one preselected so the
        # handlers never receive None; initialize_env() replaces the choices
        # with the active environment's stocks.
        stock_dropdown = gr.Dropdown(
            choices=["AAPL", "GOOG", "TSLA", "AMZN"],
            value="AAPL",
            label="Stock"
        )
        amount_slider = gr.Slider(1, 10, value=1, step=1, label="Amount")
    with gr.Row():
        buy_btn = gr.Button("🟢 Buy")
        sell_btn = gr.Button("🔴 Sell")

    load_file_btn.click(load_config_from_file, [config_file_dropdown], [config_input])

    init_btn.click(
        initialize_env,
        inputs=[config_input],
        outputs=[status_output, portfolio_table, price_chart, stock_dropdown]
    )

    buy_btn.click(act_buy, [stock_dropdown, amount_slider], [status_output, portfolio_table, price_chart])
    sell_btn.click(act_sell, [stock_dropdown, amount_slider], [status_output, portfolio_table, price_chart])

# Only start the server when run as a script, so importing this module (for
# reload tooling or tests) does not launch the app.
if __name__ == "__main__":
    demo.launch()