xufangzhi commited on
Commit
2d77f57
ยท
verified ยท
1 Parent(s): f1b55c8

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +469 -0
app.py ADDED
@@ -0,0 +1,469 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import json
4
+ import pandas as pd
5
+ import plotly.graph_objects as go
6
+ from plotly.subplots import make_subplots
7
+
8
+ class TradeArenaEnv_Deterministic:
9
+ """
10
+ Odyssey Arena - AI Trading Environment (Deterministic version)
11
+ """
12
+ def __init__(self, cfg):
13
+ self.num_days = cfg["num_days"]
14
+ self.stocks = cfg["stocks"]
15
+ self.variables = cfg["variables"]
16
+ self.dependency_matrix = np.array(cfg["dependency_matrix"])
17
+ self.initial_prices = np.array(cfg["initial_prices"])
18
+ self.initial_variables = np.array(cfg["initial_variables"])
19
+ self.timeline = cfg["timeline"]
20
+ self.price_noise_scale = cfg.get("price_noise_scale", 0.0)
21
+ self.initial_cash = cfg.get("initial_cash", 10000.0)
22
+ self.reset()
23
+
24
+ def reset(self):
25
+ self.t = 0
26
+ self.cash = self.initial_cash
27
+ self.positions = np.zeros(len(self.stocks))
28
+ self.prices = self.initial_prices.copy()
29
+ self.variables_state = self.initial_variables.copy()
30
+ self.next_day_news = self.timeline.get("day_1", None)
31
+ return self._get_observation()
32
+
33
+ def _get_observation(self):
34
+ obs = {
35
+ "day": self.t,
36
+ "prices": {s: float(p) for s, p in zip(self.stocks, self.prices)},
37
+ "cash": float(self.cash),
38
+ "positions": {s: int(pos) for s, pos in zip(self.stocks, self.positions)},
39
+ "total_value": float(self.cash + np.sum(self.positions * self.prices)),
40
+ "news_next_day": self.next_day_news["variable_changes"] if self.next_day_news else None,
41
+ "news_next_day_text": self.next_day_news["news_text"] if self.next_day_news else None
42
+ }
43
+ return obs
44
+
45
+ def step(self, action):
46
+ assert isinstance(action, dict)
47
+
48
+ # Execute sells first
49
+ for stock, qty in action.get("sell", {}).items():
50
+ idx = self.stocks.index(stock)
51
+ qty = int(qty)
52
+ qty = min(qty, self.positions[idx])
53
+ revenue = self.prices[idx] * qty
54
+ self.positions[idx] -= qty
55
+ self.cash += revenue
56
+
57
+ # Then buys
58
+ for stock, qty in action.get("buy", {}).items():
59
+ idx = self.stocks.index(stock)
60
+ qty = int(qty)
61
+ cost = self.prices[idx] * qty
62
+ if cost <= self.cash:
63
+ self.positions[idx] += qty
64
+ self.cash -= cost
65
+
66
+ # Advance one day
67
+ self.t += 1
68
+ done = self.t >= self.num_days
69
+
70
+ # Update variable states & prices
71
+ if not done:
72
+ news_today = self.timeline.get(f"day_{self.t}", None)
73
+ if news_today:
74
+ deltas = np.array(news_today["variable_changes"])
75
+ self.variables_state += deltas
76
+ self._update_prices_from_variables(deltas)
77
+
78
+ # Prepare next day's news
79
+ self.next_day_news = self.timeline.get(f"day_{self.t + 1}", None) if not done else None
80
+
81
+ reward = self._compute_reward()
82
+ obs = self._get_observation()
83
+ return obs, reward, done, {}
84
+
85
+ def _update_prices_from_variables(self, delta_vars):
86
+ delta_price = self.dependency_matrix @ delta_vars
87
+ noise = np.zeros_like(delta_price) if self.price_noise_scale == 0 else np.random.normal(
88
+ 0, self.price_noise_scale, len(self.stocks)
89
+ )
90
+ self.prices += delta_price + noise
91
+ self.prices = np.clip(self.prices, 0.1, None)
92
+
93
+ def _compute_reward(self):
94
+ total_value = self.cash + np.sum(self.positions * self.prices)
95
+ return round(float(total_value), 2)
96
+
97
+
98
+ # Default configuration
99
+ DEFAULT_CONFIG = {
100
+ "num_days": 30,
101
+ "stocks": ["TECH", "ENERGY", "FINANCE"],
102
+ "variables": ["interest_rate", "oil_price", "market_sentiment"],
103
+ "dependency_matrix": [
104
+ [-5, 2, 3],
105
+ [1, 8, 2],
106
+ [-3, 1, 4]
107
+ ],
108
+ "initial_prices": [100, 80, 120],
109
+ "initial_variables": [0, 0, 0],
110
+ "initial_cash": 10000,
111
+ "price_noise_scale": 0,
112
+ "timeline": {
113
+ "day_1": {
114
+ "variable_changes": [0.1, -0.2, 0.3],
115
+ "news_text": "Federal Reserve hints at rate increase; Oil prices drop on oversupply concerns"
116
+ },
117
+ "day_2": {
118
+ "variable_changes": [-0.1, 0.3, 0.2],
119
+ "news_text": "Tech sector shows strong earnings; Energy stocks rally on production cuts"
120
+ },
121
+ "day_3": {
122
+ "variable_changes": [0.2, 0.1, -0.1],
123
+ "news_text": "Market sentiment cautious amid geopolitical tensions"
124
+ },
125
+ "day_4": {
126
+ "variable_changes": [0.0, 0.2, 0.1],
127
+ "news_text": "Stable interest rates; Energy sector momentum continues"
128
+ },
129
+ "day_5": {
130
+ "variable_changes": [-0.2, -0.1, 0.0],
131
+ "news_text": "Rate cut speculation; Market consolidation"
132
+ }
133
+ }
134
+ }
135
+
136
+ # Global state
137
+ env = None
138
+ history = []
139
+
140
+ def initialize_env(config_file=None):
141
+ global env, history
142
+
143
+ if config_file is not None:
144
+ try:
145
+ config = json.loads(config_file)
146
+ except:
147
+ return "โŒ Invalid JSON file", None, None, None, None
148
+ else:
149
+ config = DEFAULT_CONFIG
150
+
151
+ env = TradeArenaEnv_Deterministic(config)
152
+ obs = env.reset()
153
+
154
+ # Initialize history
155
+ history = [{
156
+ 'day': obs['day'],
157
+ 'total_value': obs['total_value'],
158
+ **obs['prices']
159
+ }]
160
+
161
+ status = f"โœ… Session initialized!\n๐Ÿ“… Day: {obs['day']}\n๐Ÿ’ฐ Cash: ${obs['cash']:.2f}\n๐Ÿ“Š Total Value: ${obs['total_value']:.2f}"
162
+
163
+ return (
164
+ status,
165
+ create_portfolio_display(obs),
166
+ create_news_display(obs),
167
+ create_price_chart(),
168
+ create_value_chart()
169
+ )
170
+
171
+ def create_portfolio_display(obs):
172
+ """Create portfolio summary table"""
173
+ data = []
174
+ for stock in env.stocks:
175
+ data.append({
176
+ 'Stock': stock,
177
+ 'Price': f"${obs['prices'][stock]:.2f}",
178
+ 'Holdings': obs['positions'][stock],
179
+ 'Value': f"${obs['prices'][stock] * obs['positions'][stock]:.2f}"
180
+ })
181
+
182
+ df = pd.DataFrame(data)
183
+ return df
184
+
185
+ def create_news_display(obs):
186
+ """Create news display"""
187
+ if obs['news_next_day_text']:
188
+ news_html = f"""
189
+ <div style='background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
190
+ padding: 20px; border-radius: 10px; color: white;'>
191
+ <h3>๐Ÿ“ฐ Next Day News</h3>
192
+ <p style='font-size: 16px;'>{obs['news_next_day_text']}</p>
193
+ """
194
+ if obs['news_next_day']:
195
+ news_html += "<p style='font-size: 14px; margin-top: 10px;'><b>Variable Changes:</b> "
196
+ for i, var in enumerate(env.variables):
197
+ change = obs['news_next_day'][i]
198
+ news_html += f"{var}: {'+' if change > 0 else ''}{change} | "
199
+ news_html += "</p>"
200
+ news_html += "</div>"
201
+ return news_html
202
+ else:
203
+ return "<div style='padding: 20px; background: #f0f0f0; border-radius: 10px;'>No more news available</div>"
204
+
205
+ def create_price_chart():
206
+ """Create price history chart"""
207
+ if len(history) <= 1:
208
+ return None
209
+
210
+ df = pd.DataFrame(history)
211
+
212
+ fig = go.Figure()
213
+ colors = ['#3b82f6', '#10b981', '#f59e0b', '#ef4444', '#8b5cf6']
214
+
215
+ for i, stock in enumerate(env.stocks):
216
+ fig.add_trace(go.Scatter(
217
+ x=df['day'],
218
+ y=df[stock],
219
+ mode='lines+markers',
220
+ name=stock,
221
+ line=dict(width=3, color=colors[i % len(colors)])
222
+ ))
223
+
224
+ fig.update_layout(
225
+ title='Stock Price History',
226
+ xaxis_title='Day',
227
+ yaxis_title='Price ($)',
228
+ hovermode='x unified',
229
+ template='plotly_white',
230
+ height=400
231
+ )
232
+
233
+ return fig
234
+
235
+ def create_value_chart():
236
+ """Create portfolio value chart"""
237
+ if len(history) <= 1:
238
+ return None
239
+
240
+ df = pd.DataFrame(history)
241
+
242
+ fig = go.Figure()
243
+ fig.add_trace(go.Scatter(
244
+ x=df['day'],
245
+ y=df['total_value'],
246
+ mode='lines+markers',
247
+ name='Portfolio Value',
248
+ line=dict(width=4, color='#8b5cf6'),
249
+ fill='tozeroy',
250
+ fillcolor='rgba(139, 92, 246, 0.1)'
251
+ ))
252
+
253
+ fig.update_layout(
254
+ title='Portfolio Value Over Time',
255
+ xaxis_title='Day',
256
+ yaxis_title='Total Value ($)',
257
+ template='plotly_white',
258
+ height=400
259
+ )
260
+
261
+ return fig
262
+
263
+ def execute_trade(stock, action, amount):
264
+ """Execute a buy or sell trade"""
265
+ global env, history
266
+
267
+ if env is None:
268
+ return "โŒ Please initialize the environment first", None, None, None, None
269
+
270
+ try:
271
+ amount = int(amount)
272
+ if amount <= 0:
273
+ return "โŒ Amount must be positive", None, None, None, None
274
+
275
+ if action == "Buy":
276
+ trade_action = {"buy": {stock: amount}, "sell": {}}
277
+ else:
278
+ trade_action = {"buy": {}, "sell": {stock: amount}}
279
+
280
+ # Execute trade (modify positions without advancing day)
281
+ if action == "Sell":
282
+ idx = env.stocks.index(stock)
283
+ qty = min(amount, env.positions[idx])
284
+ if qty == 0:
285
+ return f"โŒ No shares to sell", None, None, None, None
286
+ revenue = env.prices[idx] * qty
287
+ env.positions[idx] -= qty
288
+ env.cash += revenue
289
+ status = f"โœ… Sold {qty} shares of {stock} at ${env.prices[idx]:.2f}"
290
+ else: # Buy
291
+ idx = env.stocks.index(stock)
292
+ cost = env.prices[idx] * amount
293
+ if cost > env.cash:
294
+ return f"โŒ Insufficient cash! Need ${cost:.2f}, have ${env.cash:.2f}", None, None, None, None
295
+ env.positions[idx] += amount
296
+ env.cash -= cost
297
+ status = f"โœ… Bought {amount} shares of {stock} at ${env.prices[idx]:.2f}"
298
+
299
+ obs = env._get_observation()
300
+ status += f"\n๐Ÿ’ฐ Cash: ${obs['cash']:.2f}\n๐Ÿ“Š Total Value: ${obs['total_value']:.2f}"
301
+
302
+ return (
303
+ status,
304
+ create_portfolio_display(obs),
305
+ create_news_display(obs),
306
+ create_price_chart(),
307
+ create_value_chart()
308
+ )
309
+
310
+ except Exception as e:
311
+ return f"โŒ Error: {str(e)}", None, None, None, None
312
+
313
+ def advance_day():
314
+ """Advance to next day"""
315
+ global env, history
316
+
317
+ if env is None:
318
+ return "โŒ Please initialize the environment first", None, None, None, None
319
+
320
+ try:
321
+ obs, reward, done, info = env.step({"buy": {}, "sell": {}})
322
+
323
+ # Add to history
324
+ history.append({
325
+ 'day': obs['day'],
326
+ 'total_value': obs['total_value'],
327
+ **obs['prices']
328
+ })
329
+
330
+ if done:
331
+ status = f"๐Ÿ Simulation Complete!\n๐Ÿ“… Final Day: {obs['day']}\n๐Ÿ’ฐ Final Cash: ${obs['cash']:.2f}\n๐Ÿ“Š Final Value: ${obs['total_value']:.2f}"
332
+ else:
333
+ status = f"โœ… Advanced to Day {obs['day']}\n๐Ÿ’ฐ Cash: ${obs['cash']:.2f}\n๐Ÿ“Š Total Value: ${obs['total_value']:.2f}"
334
+
335
+ return (
336
+ status,
337
+ create_portfolio_display(obs),
338
+ create_news_display(obs),
339
+ create_price_chart(),
340
+ create_value_chart()
341
+ )
342
+
343
+ except Exception as e:
344
+ return f"โŒ Error: {str(e)}", None, None, None, None
345
+
346
+ def reset_env():
347
+ """Reset the environment"""
348
+ global env, history
349
+
350
+ if env is None:
351
+ return initialize_env()
352
+
353
+ obs = env.reset()
354
+ history = [{
355
+ 'day': obs['day'],
356
+ 'total_value': obs['total_value'],
357
+ **obs['prices']
358
+ }]
359
+
360
+ status = f"๐Ÿ”„ Environment Reset!\n๐Ÿ“… Day: {obs['day']}\n๐Ÿ’ฐ Cash: ${obs['cash']:.2f}\n๐Ÿ“Š Total Value: ${obs['total_value']:.2f}"
361
+
362
+ return (
363
+ status,
364
+ create_portfolio_display(obs),
365
+ create_news_display(obs),
366
+ None,
367
+ None
368
+ )
369
+
370
+ # Create Gradio Interface
371
+ with gr.Blocks(theme=gr.themes.Soft(), title="AI Trading Arena") as demo:
372
+ gr.Markdown(
373
+ """
374
+ # ๐Ÿš€ AI Trading Arena
375
+ ### Interactive Stock Trading Simulator
376
+ Upload your config or use the default configuration to start trading!
377
+ """
378
+ )
379
+
380
+ with gr.Row():
381
+ with gr.Column(scale=1):
382
+ gr.Markdown("## ๐ŸŽฎ Control Panel")
383
+
384
+ with gr.Accordion("๐Ÿ“ Configuration", open=True):
385
+ config_input = gr.Textbox(
386
+ label="Upload Config JSON (or leave empty for default)",
387
+ placeholder='Paste JSON config here or leave empty',
388
+ lines=5
389
+ )
390
+ init_btn = gr.Button("๐Ÿš€ Initialize/Load Config", variant="primary", size="lg")
391
+ reset_btn = gr.Button("๐Ÿ”„ Reset Environment", variant="secondary")
392
+
393
+ with gr.Accordion("๐Ÿ’น Trading", open=True):
394
+ stock_dropdown = gr.Dropdown(
395
+ choices=DEFAULT_CONFIG["stocks"],
396
+ label="Select Stock",
397
+ value=DEFAULT_CONFIG["stocks"][0]
398
+ )
399
+ action_radio = gr.Radio(
400
+ choices=["Buy", "Sell"],
401
+ label="Action",
402
+ value="Buy"
403
+ )
404
+ amount_input = gr.Number(
405
+ label="Amount (shares)",
406
+ value=1,
407
+ minimum=1
408
+ )
409
+ trade_btn = gr.Button("๐Ÿ“ˆ Execute Trade", variant="primary")
410
+
411
+ advance_btn = gr.Button("โญ๏ธ Advance to Next Day", variant="primary", size="lg")
412
+
413
+ status_output = gr.Textbox(
414
+ label="๐Ÿ“Š Status",
415
+ lines=5,
416
+ interactive=False
417
+ )
418
+
419
+ with gr.Column(scale=2):
420
+ gr.Markdown("## ๐Ÿ“Š Market Overview")
421
+
422
+ portfolio_table = gr.Dataframe(
423
+ label="Portfolio Holdings",
424
+ interactive=False
425
+ )
426
+
427
+ news_display = gr.HTML(label="News")
428
+
429
+ with gr.Tabs():
430
+ with gr.Tab("๐Ÿ“ˆ Price History"):
431
+ price_chart = gr.Plot(label="Stock Prices")
432
+
433
+ with gr.Tab("๐Ÿ’ฐ Portfolio Value"):
434
+ value_chart = gr.Plot(label="Total Value")
435
+
436
+ # Event handlers
437
+ init_btn.click(
438
+ fn=initialize_env,
439
+ inputs=[config_input],
440
+ outputs=[status_output, portfolio_table, news_display, price_chart, value_chart]
441
+ )
442
+
443
+ reset_btn.click(
444
+ fn=reset_env,
445
+ inputs=[],
446
+ outputs=[status_output, portfolio_table, news_display, price_chart, value_chart]
447
+ )
448
+
449
+ trade_btn.click(
450
+ fn=execute_trade,
451
+ inputs=[stock_dropdown, action_radio, amount_input],
452
+ outputs=[status_output, portfolio_table, news_display, price_chart, value_chart]
453
+ )
454
+
455
+ advance_btn.click(
456
+ fn=advance_day,
457
+ inputs=[],
458
+ outputs=[status_output, portfolio_table, news_display, price_chart, value_chart]
459
+ )
460
+
461
+ # Initialize on load
462
+ demo.load(
463
+ fn=initialize_env,
464
+ inputs=[],
465
+ outputs=[status_output, portfolio_table, news_display, price_chart, value_chart]
466
+ )
467
+
468
+ if __name__ == "__main__":
469
+ demo.launch()