"""Odyssey Arena - interactive Gradio front end for the AI trading environment.

The module holds a single, process-wide environment (``env``) together with
the per-day ``history`` used by the charts; every UI callback reads and
mutates that shared state.
"""

import html
import json
import os

import gradio as gr
import matplotlib
import matplotlib.pyplot as plt  # noqa: F401  (kept: pre-existing file-level import)
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from matplotlib.figure import Figure

matplotlib.use('Agg')


class TradeArenaEnv_Deterministic:
    """
    Odyssey Arena - AI Trading Environment (Deterministic version)

    The environment is driven by a config dict with the keys ``num_days``,
    ``stocks``, ``variables``, ``dependency_matrix``, ``initial_prices``,
    ``initial_variables`` and ``timeline`` (all required), plus the optional
    ``price_noise_scale`` (default 0.0) and ``initial_cash`` (default 10000.0).

    ``timeline`` maps ``"day_<N>"`` to a dict holding ``variable_changes``
    (one delta per variable) and ``news_text``.  ``dependency_matrix`` is
    multiplied with the variable deltas to obtain one price delta per stock,
    so it has one row per stock and one column per variable.
    """

    def __init__(self, cfg):
        """Read the configuration and put the environment in its day-0 state.

        Raises:
            KeyError: if a required config key is missing.
        """
        self.num_days = cfg["num_days"]
        self.stocks = cfg["stocks"]
        self.variables = cfg["variables"]
        self.dependency_matrix = np.array(cfg["dependency_matrix"])
        self.initial_prices = np.array(cfg["initial_prices"])
        self.initial_variables = np.array(cfg["initial_variables"])
        self.timeline = cfg["timeline"]
        self.price_noise_scale = cfg.get("price_noise_scale", 0.0)
        self.initial_cash = cfg.get("initial_cash", 10000.0)

        self.reset()

    def reset(self):
        """Restore cash, positions, prices and variables; return the observation."""
        self.t = 0
        self.cash = self.initial_cash
        self.positions = np.zeros(len(self.stocks), dtype=np.float64)
        self.prices = self.initial_prices.copy().astype(np.float64)
        self.variables_state = self.initial_variables.copy().astype(np.float64)
        # The agent always sees the news that will be applied on the next step.
        self.next_day_news = self.timeline.get("day_1", None)
        return self._get_observation()

    def _get_observation(self):
        """Build the observation dict for the current state.

        Keys: ``day``, ``prices``, ``cash``, ``positions``, ``total_value``,
        ``news_next_day`` and ``news_next_day_text`` (the last two are None
        when no news is scheduled).
        """
        news = self.next_day_news
        return {
            "day": self.t,
            "prices": {s: float(p) for s, p in zip(self.stocks, self.prices)},
            "cash": float(self.cash),
            "positions": {s: int(pos) for s, pos in zip(self.stocks, self.positions)},
            "total_value": float(self.cash + np.sum(self.positions * self.prices)),
            "news_next_day": news["variable_changes"] if news else None,
            "news_next_day_text": news["news_text"] if news else None,
        }

    def step(self, action):
        """Apply the orders in *action*, then advance the simulation one day.

        Args:
            action: dict with optional ``"sell"`` and ``"buy"`` entries, each
                mapping a stock name to a share quantity.  Sells are executed
                before buys.  Sell quantities are capped at the current
                holding; a buy that costs more than the available cash is
                skipped.  Non-positive quantities are ignored.

        Returns:
            ``(observation, reward, done, info)`` where ``reward`` is the
            total portfolio value rounded to 2 decimals and ``info`` is an
            empty dict.

        Raises:
            AssertionError: if *action* is not a dict.
            ValueError: if an order names a stock that is not configured.
        """
        # Kept as an assertion so existing callers observe the same exception.
        assert isinstance(action, dict)

        # Execute sells first
        for stock, qty in action.get("sell", {}).items():
            idx = self.stocks.index(stock)
            qty = min(int(qty), int(self.positions[idx]))
            if qty <= 0:
                # A negative sell would otherwise act as a buy that skips
                # the cash check.
                continue
            self.positions[idx] -= qty
            self.cash += self.prices[idx] * qty

        # Then buys
        for stock, qty in action.get("buy", {}).items():
            idx = self.stocks.index(stock)
            qty = int(qty)
            if qty <= 0:
                # A negative buy would otherwise create an unlimited short
                # position and credit cash.
                continue
            cost = self.prices[idx] * qty
            if cost <= self.cash:
                self.positions[idx] += qty
                self.cash -= cost

        # Advance one day
        self.t += 1
        done = self.t >= self.num_days

        # Update variable states & prices
        if not done:
            news_today = self.timeline.get(f"day_{self.t}", None)
            if news_today:
                deltas = np.array(news_today["variable_changes"])
                self.variables_state += deltas
                self._update_prices_from_variables(deltas)

        # Prepare next day's news
        self.next_day_news = (
            self.timeline.get(f"day_{self.t + 1}", None) if not done else None
        )

        reward = self._compute_reward()
        obs = self._get_observation()
        return obs, reward, done, {}

    def _update_prices_from_variables(self, delta_vars):
        """Shift prices by ``dependency_matrix @ delta_vars`` (plus optional noise).

        Prices are clipped from below at 0.1.
        """
        delta_price = self.dependency_matrix @ delta_vars
        if self.price_noise_scale == 0:
            noise = np.zeros_like(delta_price)
        else:
            noise = np.random.normal(0, self.price_noise_scale, len(self.stocks))
        self.prices += delta_price + noise
        self.prices = np.clip(self.prices, 0.1, None)

    def _compute_reward(self):
        """Return cash plus the market value of all holdings, rounded to 2 decimals."""
        total_value = self.cash + np.sum(self.positions * self.prices)
        return round(float(total_value), 2)


# Default configuration
DEFAULT_CONFIG = {
    "num_days": 30,
    "stocks": ["TECH", "ENERGY", "FINANCE"],
    "variables": ["interest_rate", "oil_price", "market_sentiment"],
    "dependency_matrix": [
        [-5, 2, 3],
        [1, 8, 2],
        [-3, 1, 4]
    ],
    "initial_prices": [100, 80, 120],
    "initial_variables": [0, 0, 0],
    "initial_cash": 10000,
    "price_noise_scale": 0,
    "timeline": {
        "day_1": {
            "variable_changes": [0.1, -0.2, 0.3],
            "news_text": "Federal Reserve hints at rate increase; Oil prices drop on oversupply concerns"
        },
        "day_2": {
            "variable_changes": [-0.1, 0.3, 0.2],
            "news_text": "Tech sector shows strong earnings; Energy stocks rally on production cuts"
        },
        "day_3": {
            "variable_changes": [0.2, 0.1, -0.1],
            "news_text": "Market sentiment cautious amid geopolitical tensions"
        },
        "day_4": {
            "variable_changes": [0.0, 0.2, 0.1],
            "news_text": "Stable interest rates; Energy sector momentum continues"
        },
        "day_5": {
            "variable_changes": [-0.2, -0.1, 0.0],
            "news_text": "Rate cut speculation; Market consolidation"
        }
    }
}


# ===== Added: config directory support =====
CONFIG_DIR = "config"


def list_config_files():
    """Return the sorted names of the ``.json`` files in the config directory.

    Returns an empty list when the directory does not exist (or is not a
    directory).
    """
    if not os.path.isdir(CONFIG_DIR):
        return []
    return sorted(f for f in os.listdir(CONFIG_DIR) if f.endswith(".json"))


def load_config_from_file(filename):
    """Load a JSON file from the config directory into the input box.

    Args:
        filename: name of a file inside the config directory.

    Returns:
        The pretty-printed JSON text, or a message starting with "❌" when
        nothing is selected or the file cannot be read or parsed.
    """
    if not filename:
        return "❌ No config file selected"
    try:
        # basename() keeps the lookup inside the config directory.
        path = os.path.join(CONFIG_DIR, os.path.basename(filename))
        with open(path, "r", encoding="utf-8") as f:
            cfg = json.load(f)
        return json.dumps(cfg, indent=2, ensure_ascii=False)
    except (OSError, ValueError) as e:
        # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
        return f"❌ Failed to load {filename}: {str(e)}"


# Global state (shared by every UI callback)
env = None
history = []


def _history_entry(obs):
    """Return the chart record (day, total value, per-stock prices) for *obs*."""
    return {'day': obs['day'], 'total_value': obs['total_value'], **obs['prices']}


def _dashboard(status, obs):
    """Return the 5-tuple (status, portfolio, news, price chart, value chart)."""
    return (
        status,
        create_portfolio_display(obs),
        create_news_display(obs),
        create_price_chart(),
        create_value_chart()
    )


def _failure(status):
    """Return an error status without wiping an already-rendered dashboard.

    Returning None for a component clears it, so when a session exists the
    current views are re-rendered next to the error message instead.
    """
    if env is None:
        return status, None, None, None, None
    return _dashboard(status, env._get_observation())


def _finished_status(obs):
    """Return the end-of-episode summary (final value and profit vs. day 0)."""
    init_val = history[0]['total_value']
    profit = obs['total_value'] - init_val
    # Guard against a configured initial value of 0.
    pct = profit / init_val * 100 if init_val else 0.0
    return (
        f"🏁 Finished! Final value: ${obs['total_value']:.2f}\n"
        f"Profit: {profit:+.2f} ({pct:+.2f}%)"
    )


def initialize_env(config_file=None):
    """Create a fresh environment and reset the chart history.

    Args:
        config_file: JSON text of a config; None or a blank string selects
            ``DEFAULT_CONFIG``.

    Returns:
        The 5-tuple (status, portfolio table, news HTML, price chart,
        value chart).  On an invalid config the status starts with "❌" and
        the previous session, if any, is left untouched.
    """
    global env, history

    if config_file is not None and config_file.strip():
        try:
            config = json.loads(config_file)
        except json.JSONDecodeError:
            return _failure("❌ Invalid JSON file")
    else:
        config = DEFAULT_CONFIG

    try:
        new_env = TradeArenaEnv_Deterministic(config)
    except (KeyError, TypeError, ValueError, AttributeError) as e:
        # Valid JSON, but not a usable config (missing key, wrong shape, ...).
        return _failure(f"❌ Invalid config: {e!r}")

    env = new_env
    obs = env.reset()

    # Initialize history
    history = [_history_entry(obs)]

    status = f"✅ Session initialized!\n📅 Day: {obs['day']}\n💰 Cash: ${obs['cash']:.2f}\n📊 Total Value: ${obs['total_value']:.2f}"
    return _dashboard(status, obs)


def create_portfolio_display(obs):
    """Return a DataFrame with one row (price, holdings, value) per stock."""
    rows = [
        {
            'Stock': stock,
            'Price': f"${obs['prices'][stock]:.2f}",
            'Holdings': obs['positions'][stock],
            'Value': f"${obs['prices'][stock] * obs['positions'][stock]:.2f}"
        }
        for stock in env.stocks
    ]
    # Explicit columns keep the table well-formed even with no stocks.
    return pd.DataFrame(rows, columns=['Stock', 'Price', 'Holdings', 'Value'])


def create_news_display(obs):
    """Render the next day's news and variable changes as an HTML fragment.

    News text and variable names come from the (user-editable) config, so
    they are HTML-escaped before being embedded.
    """
    news_text = obs['news_next_day_text']
    if not news_text:
        return "<div>📭 No more news available</div>"

    parts = [
        "<div>",
        "<h3>📰 Next Day News</h3>",
        f"<p>{html.escape(str(news_text))}</p>",
    ]

    changes = obs['news_next_day']
    if changes:
        parts.append("<div>")
        parts.append("<b>Variable Changes:</b><br>")
        # zip() tolerates a change list shorter than the variable list.
        for var, change in zip(env.variables, changes):
            sign = '+' if change > 0 else ''
            parts.append(
                f"• {html.escape(str(var))}: {sign}{html.escape(str(change))}<br>"
            )
        parts.append("</div>")

    parts.append("</div>")
    return "\n".join(parts)


def create_price_chart():
    """Create stock price chart using Plotly"""
    if len(history) <= 1:
        # No trading history yet: return a blank figure with a hint.
        fig = go.Figure()
        fig.add_annotation(
            text="Trade to see price history",
            xref="paper", yref="paper",
            showarrow=False,
            font=dict(size=16, color="gray")
        )
        fig.update_layout(
            xaxis=dict(visible=False),
            yaxis=dict(visible=False),
            template="plotly_white",
            height=400
        )
        return fig

    df = pd.DataFrame(history)
    fig = go.Figure()

    colors = ['#3b82f6', '#10b981', '#f59e0b', '#ef4444', '#8b5cf6']
    for i, stock in enumerate(env.stocks):
        fig.add_trace(go.Scatter(
            x=df['day'],
            y=df[stock],
            mode='lines+markers',
            name=stock,
            line=dict(color=colors[i % len(colors)], width=2),
            marker=dict(size=6)
        ))

    fig.update_layout(
        title="Stock Price History",
        xaxis_title="Day",
        yaxis_title="Price ($)",
        template="plotly_white",
        legend=dict(title="Stocks", orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
        height=400 + 50 * len(env.stocks)
    )
    return fig


def create_value_chart():
    """Create the portfolio-value chart as a Matplotlib figure.

    The figure is built with ``matplotlib.figure.Figure`` rather than
    ``plt.subplots`` so it is not registered with pyplot's global figure
    manager; this callback runs on every click and pyplot figures that are
    never closed accumulate.
    """
    fig = Figure(figsize=(10, 6))
    ax = fig.subplots()

    if len(history) <= 1:
        ax.text(0.5, 0.5, 'Trade to see portfolio value',
                ha='center', va='center', fontsize=14, color='gray')
        ax.axis('off')
        return fig

    df = pd.DataFrame(history)
    ax.plot(df['day'], df['total_value'], marker='o', linewidth=3,
            color='#8b5cf6', label='Portfolio Value')
    ax.fill_between(df['day'], df['total_value'], alpha=0.2, color='#8b5cf6')
    initial_value = history[0]['total_value']
    ax.axhline(y=initial_value, color='red', linestyle='--', alpha=0.5,
               label=f'Initial: ${initial_value:.2f}')
    ax.legend()
    ax.grid(True, alpha=0.3)
    return fig


def execute_trade(stock, action, amount):
    """Buy or sell *amount* shares of *stock* at the current price.

    The day is not advanced.  Any *action* other than "Buy" is treated as a
    sell; sells are capped at the current holding.

    Returns:
        The 5-tuple (status, portfolio table, news HTML, price chart,
        value chart).  On a rejected trade the status starts with "❌" and
        the dashboard shows the unchanged state.
    """
    if env is None:
        return "❌ Please initialize the environment first", None, None, None, None

    try:
        amount = int(amount)
    except (TypeError, ValueError, OverflowError) as e:
        return _failure(f"❌ Error: {str(e)}")

    if amount <= 0:
        return _failure("❌ Amount must be positive")

    try:
        idx = env.stocks.index(stock)
    except ValueError as e:
        return _failure(f"❌ Error: {str(e)}")

    price = env.prices[idx]

    if action == "Buy":
        cost = price * amount
        if cost > env.cash:
            return _failure(
                f"❌ Insufficient cash!\nNeed: ${cost:.2f}\nHave: ${env.cash:.2f}"
            )
        env.positions[idx] += amount
        env.cash -= cost
        status = f"✅ Bought {amount} {stock} at ${price:.2f}"
    else:
        # int() so the message reads "Sold 5", not "Sold 5.0" (positions are
        # stored as floats).
        qty = min(amount, int(env.positions[idx]))
        if qty == 0:
            return _failure("❌ No shares to sell")
        env.positions[idx] -= qty
        env.cash += price * qty
        status = f"✅ Sold {qty} {stock} at ${price:.2f}"

    obs = env._get_observation()
    status += f"\n💰 Cash: ${obs['cash']:.2f} | Total Value: ${obs['total_value']:.2f}"
    return _dashboard(status, obs)


def advance_day():
    """Advance the simulation by one day and record it in the history.

    Once the episode is finished, further calls only re-display the final
    summary instead of stepping past ``num_days``.
    """
    if env is None:
        return "❌ Please initialize the environment first", None, None, None, None

    if env.t >= env.num_days:
        obs = env._get_observation()
        return _dashboard(_finished_status(obs), obs)

    obs, _reward, done, _info = env.step({"buy": {}, "sell": {}})
    history.append(_history_entry(obs))

    if done:
        status = _finished_status(obs)
    else:
        status = f"✅ Advanced to Day {obs['day']} | 💰 Cash: ${obs['cash']:.2f} | 📊 Value: ${obs['total_value']:.2f}"

    return _dashboard(status, obs)


def reset_env():
    """Reset the current environment to day 0 (or create the default one).

    Returns the same 5-tuple as ``initialize_env``, including the placeholder
    charts, so the plots are not left blank after a reset.
    """
    global history
    if env is None:
        return initialize_env()
    obs = env.reset()
    history = [_history_entry(obs)]
    status = f"🔄 Environment Reset! Day {obs['day']}, Cash ${obs['cash']:.2f}"
    return _dashboard(status, obs)


def _stock_dropdown_update():
    """Return a Gradio update that lists the current environment's stocks."""
    if env is None:
        return gr.update()
    stocks = list(env.stocks)
    return gr.update(choices=stocks, value=stocks[0] if stocks else None)


def _initialize_and_sync_stocks(config_file=None):
    """Run ``initialize_env`` and also refresh the stock dropdown choices."""
    return (*initialize_env(config_file), _stock_dropdown_update())


# ======== UI section ========
custom_css = """
.gradio-container {
    font-family: 'Arial', sans-serif;
}
"""

with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, title="AI Trading Arena") as demo:
    gr.Markdown("# 🚀 AI Trading Arena\n### Interactive Stock Trading Simulator")

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## 🎮 Control Panel")

            with gr.Accordion("📁 Configuration", open=False):
                gr.Markdown("Select a config file from `/config` folder or paste custom JSON")
                available_configs = list_config_files()
                config_file_dropdown = gr.Dropdown(
                    choices=available_configs,
                    label="Choose Config File",
                    value=available_configs[0] if available_configs else None
                )
                load_file_btn = gr.Button("📂 Load from File", variant="secondary")
                config_input = gr.Textbox(
                    label="Custom Config JSON",
                    placeholder='{"num_days": 30, "stocks": ["TECH", "ENERGY"], ...}',
                    lines=4
                )
                init_btn = gr.Button("🚀 Load Config", variant="primary", size="lg")

            with gr.Accordion("💹 Trading Operations", open=True):
                # Start with the default config's stocks; the choices are
                # refreshed whenever a config is (re)loaded.
                stock_dropdown = gr.Dropdown(
                    choices=list(DEFAULT_CONFIG["stocks"]),
                    label="Select Stock",
                    value=DEFAULT_CONFIG["stocks"][0]
                )
                action_radio = gr.Radio(choices=["Buy", "Sell"], label="Action", value="Buy")
                amount_input = gr.Number(label="Amount (shares)", value=10, minimum=1, step=1)
                trade_btn = gr.Button("📈 Execute Trade", variant="primary", size="lg")

            with gr.Row():
                advance_btn = gr.Button("⏭️ Next Day", variant="primary", size="lg")
                reset_btn = gr.Button("🔄 Reset", variant="secondary", size="lg")

            status_output = gr.Textbox(label="📊 Status & Messages", lines=8, interactive=False)

        with gr.Column(scale=2):
            gr.Markdown("## 📊 Market Dashboard")
            portfolio_table = gr.Dataframe(label="💼 Portfolio Holdings", interactive=False)
            news_display = gr.HTML(label="📰 Market News")
            with gr.Tab("📈 Price History"):
                price_chart = gr.Plot(label="Stock Prices Over Time")
            with gr.Tab("💰 Portfolio Value"):
                value_chart = gr.Plot(label="Total Portfolio Value")

    # Event bindings
    dashboard_outputs = [status_output, portfolio_table, news_display, price_chart, value_chart]

    load_file_btn.click(fn=load_config_from_file, inputs=[config_file_dropdown], outputs=[config_input])
    init_btn.click(fn=_initialize_and_sync_stocks, inputs=[config_input],
                   outputs=dashboard_outputs + [stock_dropdown])
    reset_btn.click(fn=reset_env, inputs=[], outputs=dashboard_outputs)
    trade_btn.click(fn=execute_trade, inputs=[stock_dropdown, action_radio, amount_input],
                    outputs=dashboard_outputs)
    advance_btn.click(fn=advance_day, inputs=[], outputs=dashboard_outputs)
    demo.load(fn=_initialize_and_sync_stocks, inputs=[],
              outputs=dashboard_outputs + [stock_dropdown])


if __name__ == "__main__":
    demo.launch()