xufangzhi commited on
Commit
df848f6
ยท
verified ยท
1 Parent(s): 2d77f57

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +129 -81
app.py CHANGED
@@ -2,8 +2,9 @@ import gradio as gr
2
  import numpy as np
3
  import json
4
  import pandas as pd
5
- import plotly.graph_objects as go
6
- from plotly.subplots import make_subplots
 
7
 
8
  class TradeArenaEnv_Deterministic:
9
  """
@@ -140,7 +141,7 @@ history = []
140
  def initialize_env(config_file=None):
141
  global env, history
142
 
143
- if config_file is not None:
144
  try:
145
  config = json.loads(config_file)
146
  except:
@@ -187,76 +188,80 @@ def create_news_display(obs):
187
  if obs['news_next_day_text']:
188
  news_html = f"""
189
  <div style='background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
190
- padding: 20px; border-radius: 10px; color: white;'>
191
- <h3>๐Ÿ“ฐ Next Day News</h3>
192
- <p style='font-size: 16px;'>{obs['news_next_day_text']}</p>
193
  """
194
  if obs['news_next_day']:
195
- news_html += "<p style='font-size: 14px; margin-top: 10px;'><b>Variable Changes:</b> "
196
  for i, var in enumerate(env.variables):
197
  change = obs['news_next_day'][i]
198
- news_html += f"{var}: {'+' if change > 0 else ''}{change} | "
199
  news_html += "</p>"
200
  news_html += "</div>"
201
  return news_html
202
  else:
203
- return "<div style='padding: 20px; background: #f0f0f0; border-radius: 10px;'>No more news available</div>"
204
 
205
  def create_price_chart():
206
- """Create price history chart"""
207
  if len(history) <= 1:
208
- return None
 
 
 
 
209
 
210
  df = pd.DataFrame(history)
211
 
212
- fig = go.Figure()
213
  colors = ['#3b82f6', '#10b981', '#f59e0b', '#ef4444', '#8b5cf6']
214
 
215
  for i, stock in enumerate(env.stocks):
216
- fig.add_trace(go.Scatter(
217
- x=df['day'],
218
- y=df[stock],
219
- mode='lines+markers',
220
- name=stock,
221
- line=dict(width=3, color=colors[i % len(colors)])
222
- ))
223
-
224
- fig.update_layout(
225
- title='Stock Price History',
226
- xaxis_title='Day',
227
- yaxis_title='Price ($)',
228
- hovermode='x unified',
229
- template='plotly_white',
230
- height=400
231
- )
232
 
233
  return fig
234
 
235
  def create_value_chart():
236
- """Create portfolio value chart"""
237
  if len(history) <= 1:
238
- return None
 
 
 
 
239
 
240
  df = pd.DataFrame(history)
241
 
242
- fig = go.Figure()
243
- fig.add_trace(go.Scatter(
244
- x=df['day'],
245
- y=df['total_value'],
246
- mode='lines+markers',
247
- name='Portfolio Value',
248
- line=dict(width=4, color='#8b5cf6'),
249
- fill='tozeroy',
250
- fillcolor='rgba(139, 92, 246, 0.1)'
251
- ))
252
-
253
- fig.update_layout(
254
- title='Portfolio Value Over Time',
255
- xaxis_title='Day',
256
- yaxis_title='Total Value ($)',
257
- template='plotly_white',
258
- height=400
259
- )
 
260
 
261
  return fig
262
 
@@ -286,18 +291,18 @@ def execute_trade(stock, action, amount):
286
  revenue = env.prices[idx] * qty
287
  env.positions[idx] -= qty
288
  env.cash += revenue
289
- status = f"โœ… Sold {qty} shares of {stock} at ${env.prices[idx]:.2f}"
290
  else: # Buy
291
  idx = env.stocks.index(stock)
292
  cost = env.prices[idx] * amount
293
  if cost > env.cash:
294
- return f"โŒ Insufficient cash! Need ${cost:.2f}, have ${env.cash:.2f}", None, None, None, None
295
  env.positions[idx] += amount
296
  env.cash -= cost
297
- status = f"โœ… Bought {amount} shares of {stock} at ${env.prices[idx]:.2f}"
298
 
299
  obs = env._get_observation()
300
- status += f"\n๐Ÿ’ฐ Cash: ${obs['cash']:.2f}\n๐Ÿ“Š Total Value: ${obs['total_value']:.2f}"
301
 
302
  return (
303
  status,
@@ -328,9 +333,18 @@ def advance_day():
328
  })
329
 
330
  if done:
331
- status = f"๐Ÿ Simulation Complete!\n๐Ÿ“… Final Day: {obs['day']}\n๐Ÿ’ฐ Final Cash: ${obs['cash']:.2f}\n๐Ÿ“Š Final Value: ${obs['total_value']:.2f}"
 
 
 
 
 
 
 
332
  else:
333
- status = f"โœ… Advanced to Day {obs['day']}\n๐Ÿ’ฐ Cash: ${obs['cash']:.2f}\n๐Ÿ“Š Total Value: ${obs['total_value']:.2f}"
 
 
334
 
335
  return (
336
  status,
@@ -357,7 +371,7 @@ def reset_env():
357
  **obs['prices']
358
  }]
359
 
360
- status = f"๐Ÿ”„ Environment Reset!\n๐Ÿ“… Day: {obs['day']}\n๐Ÿ’ฐ Cash: ${obs['cash']:.2f}\n๐Ÿ“Š Total Value: ${obs['total_value']:.2f}"
361
 
362
  return (
363
  status,
@@ -367,13 +381,28 @@ def reset_env():
367
  None
368
  )
369
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
370
  # Create Gradio Interface
371
- with gr.Blocks(theme=gr.themes.Soft(), title="AI Trading Arena") as demo:
372
  gr.Markdown(
373
  """
374
  # ๐Ÿš€ AI Trading Arena
375
  ### Interactive Stock Trading Simulator
376
- Upload your config or use the default configuration to start trading!
377
  """
378
  )
379
 
@@ -381,16 +410,16 @@ with gr.Blocks(theme=gr.themes.Soft(), title="AI Trading Arena") as demo:
381
  with gr.Column(scale=1):
382
  gr.Markdown("## ๐ŸŽฎ Control Panel")
383
 
384
- with gr.Accordion("๐Ÿ“ Configuration", open=True):
 
385
  config_input = gr.Textbox(
386
- label="Upload Config JSON (or leave empty for default)",
387
- placeholder='Paste JSON config here or leave empty',
388
- lines=5
389
  )
390
- init_btn = gr.Button("๐Ÿš€ Initialize/Load Config", variant="primary", size="lg")
391
- reset_btn = gr.Button("๐Ÿ”„ Reset Environment", variant="secondary")
392
 
393
- with gr.Accordion("๐Ÿ’น Trading", open=True):
394
  stock_dropdown = gr.Dropdown(
395
  choices=DEFAULT_CONFIG["stocks"],
396
  label="Select Stock",
@@ -403,35 +432,54 @@ with gr.Blocks(theme=gr.themes.Soft(), title="AI Trading Arena") as demo:
403
  )
404
  amount_input = gr.Number(
405
  label="Amount (shares)",
406
- value=1,
407
- minimum=1
 
408
  )
409
- trade_btn = gr.Button("๐Ÿ“ˆ Execute Trade", variant="primary")
 
 
410
 
411
- advance_btn = gr.Button("โญ๏ธ Advance to Next Day", variant="primary", size="lg")
 
 
412
 
413
  status_output = gr.Textbox(
414
- label="๐Ÿ“Š Status",
415
- lines=5,
416
- interactive=False
 
417
  )
418
 
419
  with gr.Column(scale=2):
420
- gr.Markdown("## ๐Ÿ“Š Market Overview")
421
 
422
  portfolio_table = gr.Dataframe(
423
- label="Portfolio Holdings",
424
- interactive=False
 
425
  )
426
 
427
- news_display = gr.HTML(label="News")
 
 
 
428
 
429
- with gr.Tabs():
430
- with gr.Tab("๐Ÿ“ˆ Price History"):
431
- price_chart = gr.Plot(label="Stock Prices")
432
-
433
- with gr.Tab("๐Ÿ’ฐ Portfolio Value"):
434
- value_chart = gr.Plot(label="Total Value")
 
 
 
 
 
 
 
 
 
435
 
436
  # Event handlers
437
  init_btn.click(
 
2
  import numpy as np
3
  import json
4
  import pandas as pd
5
+ import matplotlib.pyplot as plt
6
+ import matplotlib
7
+ matplotlib.use('Agg')
8
 
9
  class TradeArenaEnv_Deterministic:
10
  """
 
141
  def initialize_env(config_file=None):
142
  global env, history
143
 
144
+ if config_file is not None and config_file.strip():
145
  try:
146
  config = json.loads(config_file)
147
  except:
 
188
  if obs['news_next_day_text']:
189
  news_html = f"""
190
  <div style='background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
191
+ padding: 20px; border-radius: 10px; color: white; margin: 10px 0;'>
192
+ <h3 style='margin-top: 0;'>๐Ÿ“ฐ Next Day News</h3>
193
+ <p style='font-size: 16px; line-height: 1.6;'>{obs['news_next_day_text']}</p>
194
  """
195
  if obs['news_next_day']:
196
+ news_html += "<p style='font-size: 14px; margin-top: 10px;'><b>Variable Changes:</b><br/>"
197
  for i, var in enumerate(env.variables):
198
  change = obs['news_next_day'][i]
199
+ news_html += f"โ€ข {var}: <b>{'+' if change > 0 else ''}{change}</b><br/>"
200
  news_html += "</p>"
201
  news_html += "</div>"
202
  return news_html
203
  else:
204
+ return "<div style='padding: 20px; background: #f0f0f0; border-radius: 10px; text-align: center;'>๐Ÿ“ญ No more news available</div>"
205
 
206
  def create_price_chart():
207
+ """Create price history chart using matplotlib"""
208
  if len(history) <= 1:
209
+ fig, ax = plt.subplots(figsize=(10, 6))
210
+ ax.text(0.5, 0.5, 'Trade to see price history',
211
+ ha='center', va='center', fontsize=14, color='gray')
212
+ ax.axis('off')
213
+ return fig
214
 
215
  df = pd.DataFrame(history)
216
 
217
+ fig, ax = plt.subplots(figsize=(10, 6))
218
  colors = ['#3b82f6', '#10b981', '#f59e0b', '#ef4444', '#8b5cf6']
219
 
220
  for i, stock in enumerate(env.stocks):
221
+ ax.plot(df['day'], df[stock], marker='o', linewidth=2,
222
+ color=colors[i % len(colors)], label=stock)
223
+
224
+ ax.set_xlabel('Day', fontsize=12, fontweight='bold')
225
+ ax.set_ylabel('Price ($)', fontsize=12, fontweight='bold')
226
+ ax.set_title('Stock Price History', fontsize=14, fontweight='bold', pad=20)
227
+ ax.legend(loc='best', framealpha=0.9)
228
+ ax.grid(True, alpha=0.3)
229
+ ax.set_facecolor('#f8f9fa')
230
+ fig.patch.set_facecolor('white')
231
+ plt.tight_layout()
 
 
 
 
 
232
 
233
  return fig
234
 
235
  def create_value_chart():
236
+ """Create portfolio value chart using matplotlib"""
237
  if len(history) <= 1:
238
+ fig, ax = plt.subplots(figsize=(10, 6))
239
+ ax.text(0.5, 0.5, 'Trade to see portfolio value',
240
+ ha='center', va='center', fontsize=14, color='gray')
241
+ ax.axis('off')
242
+ return fig
243
 
244
  df = pd.DataFrame(history)
245
 
246
+ fig, ax = plt.subplots(figsize=(10, 6))
247
+ ax.plot(df['day'], df['total_value'], marker='o', linewidth=3,
248
+ color='#8b5cf6', label='Portfolio Value')
249
+ ax.fill_between(df['day'], df['total_value'], alpha=0.2, color='#8b5cf6')
250
+
251
+ ax.set_xlabel('Day', fontsize=12, fontweight='bold')
252
+ ax.set_ylabel('Total Value ($)', fontsize=12, fontweight='bold')
253
+ ax.set_title('Portfolio Value Over Time', fontsize=14, fontweight='bold', pad=20)
254
+ ax.legend(loc='best', framealpha=0.9)
255
+ ax.grid(True, alpha=0.3)
256
+ ax.set_facecolor('#f8f9fa')
257
+ fig.patch.set_facecolor('white')
258
+
259
+ # Add initial value line
260
+ initial_value = history[0]['total_value']
261
+ ax.axhline(y=initial_value, color='red', linestyle='--', alpha=0.5, label=f'Initial: ${initial_value:.2f}')
262
+ ax.legend(loc='best', framealpha=0.9)
263
+
264
+ plt.tight_layout()
265
 
266
  return fig
267
 
 
291
  revenue = env.prices[idx] * qty
292
  env.positions[idx] -= qty
293
  env.cash += revenue
294
+ status = f"โœ… Sold {qty} shares of {stock} at ${env.prices[idx]:.2f}\n๐Ÿ’ฐ Revenue: ${revenue:.2f}"
295
  else: # Buy
296
  idx = env.stocks.index(stock)
297
  cost = env.prices[idx] * amount
298
  if cost > env.cash:
299
+ return f"โŒ Insufficient cash!\n๐Ÿ’ต Need: ${cost:.2f}\n๐Ÿ’ฐ Have: ${env.cash:.2f}", None, None, None, None
300
  env.positions[idx] += amount
301
  env.cash -= cost
302
+ status = f"โœ… Bought {amount} shares of {stock} at ${env.prices[idx]:.2f}\n๐Ÿ’ต Cost: ${cost:.2f}"
303
 
304
  obs = env._get_observation()
305
+ status += f"\n\n๐Ÿ“Š Current Status:\n๐Ÿ’ฐ Cash: ${obs['cash']:.2f}\n๐Ÿ“ˆ Total Value: ${obs['total_value']:.2f}"
306
 
307
  return (
308
  status,
 
333
  })
334
 
335
  if done:
336
+ initial_value = history[0]['total_value']
337
+ profit = obs['total_value'] - initial_value
338
+ profit_pct = (profit / initial_value) * 100
339
+ status = f"๐Ÿ Simulation Complete!\n\n"
340
+ status += f"๐Ÿ“… Final Day: {obs['day']}\n"
341
+ status += f"๐Ÿ’ฐ Final Cash: ${obs['cash']:.2f}\n"
342
+ status += f"๐Ÿ“Š Final Value: ${obs['total_value']:.2f}\n\n"
343
+ status += f"{'๐Ÿ“ˆ' if profit >= 0 else '๐Ÿ“‰'} P&L: ${profit:+.2f} ({profit_pct:+.2f}%)"
344
  else:
345
+ status = f"โœ… Advanced to Day {obs['day']}\n\n"
346
+ status += f"๐Ÿ’ฐ Cash: ${obs['cash']:.2f}\n"
347
+ status += f"๐Ÿ“Š Total Value: ${obs['total_value']:.2f}"
348
 
349
  return (
350
  status,
 
371
  **obs['prices']
372
  }]
373
 
374
+ status = f"๐Ÿ”„ Environment Reset!\n\n๐Ÿ“… Day: {obs['day']}\n๐Ÿ’ฐ Cash: ${obs['cash']:.2f}\n๐Ÿ“Š Total Value: ${obs['total_value']:.2f}"
375
 
376
  return (
377
  status,
 
381
  None
382
  )
383
 
384
+ # Custom CSS for better styling
385
+ custom_css = """
386
+ .gradio-container {
387
+ font-family: 'Arial', sans-serif;
388
+ }
389
+ .gr-button-primary {
390
+ background: linear-gradient(90deg, #667eea 0%, #764ba2 100%) !important;
391
+ border: none !important;
392
+ }
393
+ .gr-button-secondary {
394
+ background: linear-gradient(90deg, #f093fb 0%, #f5576c 100%) !important;
395
+ border: none !important;
396
+ }
397
+ """
398
+
399
  # Create Gradio Interface
400
+ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, title="AI Trading Arena") as demo:
401
  gr.Markdown(
402
  """
403
  # ๐Ÿš€ AI Trading Arena
404
  ### Interactive Stock Trading Simulator
405
+ Test your trading strategies in a deterministic market environment!
406
  """
407
  )
408
 
 
410
  with gr.Column(scale=1):
411
  gr.Markdown("## ๐ŸŽฎ Control Panel")
412
 
413
+ with gr.Accordion("๐Ÿ“ Configuration", open=False):
414
+ gr.Markdown("Upload your custom JSON config or leave empty to use default")
415
  config_input = gr.Textbox(
416
+ label="Custom Config JSON",
417
+ placeholder='{"num_days": 30, "stocks": ["TECH", "ENERGY"], ...}',
418
+ lines=3
419
  )
420
+ init_btn = gr.Button("๐Ÿš€ Load Config", variant="primary", size="lg")
 
421
 
422
+ with gr.Accordion("๐Ÿ’น Trading Operations", open=True):
423
  stock_dropdown = gr.Dropdown(
424
  choices=DEFAULT_CONFIG["stocks"],
425
  label="Select Stock",
 
432
  )
433
  amount_input = gr.Number(
434
  label="Amount (shares)",
435
+ value=10,
436
+ minimum=1,
437
+ step=1
438
  )
439
+ trade_btn = gr.Button("๐Ÿ“ˆ Execute Trade", variant="primary", size="lg")
440
+
441
+ gr.Markdown("---")
442
 
443
+ with gr.Row():
444
+ advance_btn = gr.Button("โญ๏ธ Next Day", variant="primary", size="lg")
445
+ reset_btn = gr.Button("๐Ÿ”„ Reset", variant="secondary", size="lg")
446
 
447
  status_output = gr.Textbox(
448
+ label="๐Ÿ“Š Status & Messages",
449
+ lines=8,
450
+ interactive=False,
451
+ show_copy_button=True
452
  )
453
 
454
  with gr.Column(scale=2):
455
+ gr.Markdown("## ๐Ÿ“Š Market Dashboard")
456
 
457
  portfolio_table = gr.Dataframe(
458
+ label="๐Ÿ’ผ Portfolio Holdings",
459
+ interactive=False,
460
+ wrap=True
461
  )
462
 
463
+ news_display = gr.HTML(label="๐Ÿ“ฐ Market News")
464
+
465
+ with gr.Tab("๐Ÿ“ˆ Price History"):
466
+ price_chart = gr.Plot(label="Stock Prices Over Time")
467
 
468
+ with gr.Tab("๐Ÿ’ฐ Portfolio Value"):
469
+ value_chart = gr.Plot(label="Total Portfolio Value")
470
+
471
+ gr.Markdown(
472
+ """
473
+ ---
474
+ ### ๐Ÿ“– How to Use
475
+ 1. **Initialize**: Click "Load Config" or start with default settings
476
+ 2. **Trade**: Select stock, choose Buy/Sell, enter amount, and execute
477
+ 3. **Advance**: Click "Next Day" to see how news affects prices
478
+ 4. **Monitor**: Watch your portfolio value change over time
479
+
480
+ ๐Ÿ’ก **Tip**: Check the news preview to make informed trading decisions!
481
+ """
482
+ )
483
 
484
  # Event handlers
485
  init_btn.click(