Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import numpy as np | |
| import json | |
| import pandas as pd | |
| import matplotlib.pyplot as plt | |
| import matplotlib | |
| import os | |
| import plotly.graph_objects as go | |
| import pandas as pd | |
# Use matplotlib's non-interactive Agg backend so figures can be rendered on a
# server without a display.
matplotlib.use('Agg')
class TradeArenaEnv_Deterministic:
    """
    Odyssey Arena - AI Trading Environment (Deterministic version)

    A day-stepped market simulator. Stock prices move in response to scripted
    changes of latent "variables", linked by ``dependency_matrix`` (one row per
    stock, one column per variable):
    ``price_delta = dependency_matrix @ variable_delta``.

    Required ``cfg`` keys: ``num_days``, ``stocks``, ``variables``,
    ``dependency_matrix``, ``initial_prices``, ``initial_variables``,
    ``timeline``. Optional keys: ``price_noise_scale`` (default 0.0; when
    non-zero, Gaussian noise from NumPy's global RNG is added to every price
    update, so runs are no longer deterministic) and ``initial_cash``
    (default 10000.0).

    ``timeline`` maps ``"day_<k>"`` to ``{"variable_changes": [...],
    "news_text": str}``. The entry for day k is applied when ``step`` moves the
    clock to day k (only while k < num_days) and is exposed one day earlier as
    the "next day" news in observations.
    NOTE(review): the entry for k == num_days is shown as upcoming news but is
    never applied, because the episode ends on that step.
    """

    def __init__(self, cfg):
        self.num_days = cfg["num_days"]
        self.stocks = cfg["stocks"]
        self.variables = cfg["variables"]
        self.dependency_matrix = np.array(cfg["dependency_matrix"])
        self.initial_prices = np.array(cfg["initial_prices"])
        self.initial_variables = np.array(cfg["initial_variables"])
        self.timeline = cfg["timeline"]
        self.price_noise_scale = cfg.get("price_noise_scale", 0.0)
        self.initial_cash = cfg.get("initial_cash", 10000.0)
        self.reset()

    def reset(self):
        """Restore day 0: initial cash, no holdings, initial prices and variables.

        Returns:
            The day-0 observation (see ``_get_observation``).
        """
        self.t = 0
        self.cash = self.initial_cash
        self.positions = np.zeros(len(self.stocks), dtype=np.float64)
        self.prices = self.initial_prices.copy().astype(np.float64)
        self.variables_state = self.initial_variables.copy().astype(np.float64)
        self.next_day_news = self.timeline.get("day_1", None)
        return self._get_observation()

    def _get_observation(self):
        """Return a plain-Python snapshot of the current state.

        Keys: ``day``, ``prices`` and ``positions`` (dicts keyed by stock name;
        positions cast to int), ``cash``, ``total_value`` (cash plus holdings at
        current prices), and the upcoming news as ``news_next_day`` (variable
        changes) / ``news_next_day_text`` (headline). Both news keys are None
        when no news is scheduled.
        """
        obs = {
            "day": self.t,
            "prices": {s: float(p) for s, p in zip(self.stocks, self.prices)},
            "cash": float(self.cash),
            "positions": {s: int(pos) for s, pos in zip(self.stocks, self.positions)},
            "total_value": float(self.cash + np.sum(self.positions * self.prices)),
            "news_next_day": self.next_day_news["variable_changes"] if self.next_day_news else None,
            "news_next_day_text": self.next_day_news["news_text"] if self.next_day_news else None
        }
        return obs

    def step(self, action):
        """Execute one day's trades, advance the clock and apply that day's news.

        Args:
            action: dict with optional ``"sell"`` and ``"buy"`` mappings of
                stock name -> share count (converted with ``int``). Sells run
                before buys at today's prices, so sale proceeds can fund
                purchases. A sell is capped at the shares held. A buy whose full
                cost exceeds available cash is skipped. Non-positive quantities
                are ignored, so they cannot be used to short-sell or to acquire
                shares without a cash check.

        Returns:
            ``(obs, reward, done, info)``: the new observation, the total
            portfolio value (see ``_compute_reward``), whether day
            ``num_days`` has been reached, and an empty info dict.

        Raises:
            AssertionError: if ``action`` is not a dict.
            ValueError: if a stock name is not in ``self.stocks``.
        """
        assert isinstance(action, dict)
        # Execute sells first so their proceeds are available to the buys.
        for stock, qty in action.get("sell", {}).items():
            idx = self.stocks.index(stock)
            qty = min(int(qty), self.positions[idx])
            if qty <= 0:
                # Nothing held, or a non-positive request: never let a sell add shares.
                continue
            self.positions[idx] -= qty
            self.cash += self.prices[idx] * qty
        # Then buys, each all-or-nothing against the remaining cash.
        for stock, qty in action.get("buy", {}).items():
            idx = self.stocks.index(stock)
            qty = int(qty)
            if qty <= 0:
                # A negative "buy" would otherwise short the stock and credit cash.
                continue
            cost = self.prices[idx] * qty
            if cost <= self.cash:
                self.positions[idx] += qty
                self.cash -= cost
        # Advance one day
        self.t += 1
        done = self.t >= self.num_days
        # Update variable states & prices (not on the final, terminating step)
        if not done:
            news_today = self.timeline.get(f"day_{self.t}", None)
            if news_today:
                deltas = np.array(news_today["variable_changes"])
                self.variables_state += deltas
                self._update_prices_from_variables(deltas)
        # Prepare next day's news
        self.next_day_news = self.timeline.get(f"day_{self.t + 1}", None) if not done else None
        reward = self._compute_reward()
        obs = self._get_observation()
        return obs, reward, done, {}

    def _update_prices_from_variables(self, delta_vars):
        """Shift prices by ``dependency_matrix @ delta_vars`` (+ optional noise).

        Prices are floored at 0.1 so they stay positive.
        """
        delta_price = self.dependency_matrix @ delta_vars
        noise = np.zeros_like(delta_price) if self.price_noise_scale == 0 else np.random.normal(
            0, self.price_noise_scale, len(self.stocks)
        )
        self.prices += delta_price + noise
        self.prices = np.clip(self.prices, 0.1, None)

    def _compute_reward(self):
        """Return cash plus holdings at current prices, rounded to 2 decimals."""
        total_value = self.cash + np.sum(self.positions * self.prices)
        return round(float(total_value), 2)
# Default configuration: the built-in scenario used when no custom config JSON
# is supplied. Three stocks react to three macro variables through
# `dependency_matrix` (one row per stock, one column per variable). The
# scripted news feed covers days 1-5 only; with noise disabled, prices stay
# flat for the remainder of the 30-day episode.
_DEFAULT_NEWS_FEED = (
    ([0.1, -0.2, 0.3],
     "Federal Reserve hints at rate increase; Oil prices drop on oversupply concerns"),
    ([-0.1, 0.3, 0.2],
     "Tech sector shows strong earnings; Energy stocks rally on production cuts"),
    ([0.2, 0.1, -0.1],
     "Market sentiment cautious amid geopolitical tensions"),
    ([0.0, 0.2, 0.1],
     "Stable interest rates; Energy sector momentum continues"),
    ([-0.2, -0.1, 0.0],
     "Rate cut speculation; Market consolidation"),
)

DEFAULT_CONFIG = dict(
    num_days=30,
    stocks=["TECH", "ENERGY", "FINANCE"],
    variables=["interest_rate", "oil_price", "market_sentiment"],
    dependency_matrix=[
        [-5, 2, 3],   # TECH
        [1, 8, 2],    # ENERGY
        [-3, 1, 4],   # FINANCE
    ],
    initial_prices=[100, 80, 120],
    initial_variables=[0, 0, 0],
    initial_cash=10000,
    price_noise_scale=0,
    # "day_1" .. "day_5", in feed order.
    timeline={
        f"day_{day}": {"variable_changes": changes, "news_text": text}
        for day, (changes, text) in enumerate(_DEFAULT_NEWS_FEED, start=1)
    },
)
| # ===== New: config directory support ===== | |
def list_config_files(config_dir="config"):
    """Return the names of the ``*.json`` files in *config_dir*, sorted.

    Args:
        config_dir: Directory to scan. A relative path resolves against the
            current working directory. Defaults to ``"config"``.

    Returns:
        Alphabetically sorted list of bare file names (not paths). Returns an
        empty list when *config_dir* is missing or is not a directory.
        Sub-directories whose names end in ``.json`` are skipped.
    """
    if not os.path.isdir(config_dir):
        return []
    # os.listdir order is arbitrary; sort so the dropdown (whose default is
    # the first entry) is stable across runs and platforms.
    return sorted(
        name for name in os.listdir(config_dir)
        if name.endswith(".json") and os.path.isfile(os.path.join(config_dir, name))
    )
def load_config_from_file(filename, config_dir="config"):
    """Load a JSON file from the config directory into the config input box.

    Args:
        filename: Bare file name inside *config_dir* (the dropdown value). It
            may be None when no file is selected.
        config_dir: Directory holding config files. Defaults to ``"config"``.

    Returns:
        The parsed config re-serialised with 2-space indentation. Non-ASCII
        text is kept as-is rather than ``\\u`` escaped. If the file cannot be
        read or parsed, returns an error message containing
        "Failed to load <filename>" instead of raising.
    """
    try:
        path = os.path.join(config_dir, filename)
        # Explicit UTF-8: the platform default encoding (e.g. cp1252 on
        # Windows) fails on configs containing non-ASCII news text.
        with open(path, "r", encoding="utf-8") as f:
            cfg = json.load(f)
        return json.dumps(cfg, indent=2, ensure_ascii=False)
    except Exception as e:
        # UI boundary: show the problem in the textbox instead of raising.
        return f"โ Failed to load {filename}: {str(e)}"
# Session state shared by every Gradio callback below.
# `env` stays None until initialize_env() builds an environment.
env = None
# One snapshot dict per visited day ('day', 'total_value' and one price per
# stock), read by the chart builders.
history = list()
def initialize_env(config_file=None):
    """Start a fresh trading session and return the five UI outputs.

    Args:
        config_file: JSON text of a config (the "Custom Config JSON" textbox).
            None or blank text selects ``DEFAULT_CONFIG``.

    Returns:
        ``(status, portfolio_df, news_html, price_chart, value_chart)``. If the
        text is not valid JSON, or is valid JSON but not a usable config
        (missing key, wrong type or shape), the status carries the message and
        the other four items are None. Any previously active session is left
        untouched in that case.

    Side effects:
        Replaces the module-level ``env`` and ``history`` on success.
    """
    global env, history
    if config_file is not None and config_file.strip():
        try:
            config = json.loads(config_file)
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            return "โ Invalid JSON file", None, None, None, None
    else:
        config = DEFAULT_CONFIG
    try:
        new_env = TradeArenaEnv_Deterministic(config)
    except (KeyError, TypeError, ValueError, AttributeError) as e:
        # Parsed fine but the structure is wrong (e.g. a missing "stocks" key,
        # a list instead of an object, or ragged matrices).
        return f"โ Invalid config: {e!r}", None, None, None, None
    env = new_env
    obs = env.reset()
    # Seed the history with the day-0 snapshot.
    history = [{
        'day': obs['day'],
        'total_value': obs['total_value'],
        **obs['prices']
    }]
    status = f"โ Session initialized!\n๐ Day: {obs['day']}\n๐ฐ Cash: ${obs['cash']:.2f}\n๐ Total Value: ${obs['total_value']:.2f}"
    return (
        status,
        create_portfolio_display(obs),
        create_news_display(obs),
        create_price_chart(),
        create_value_chart()
    )
def create_portfolio_display(obs):
    """Tabulate each stock's price, share count and position value.

    Args:
        obs: Observation dict whose ``prices`` and ``positions`` are keyed by
            stock name.

    Returns:
        pandas.DataFrame with columns Stock, Price, Holdings and Value. There is
        one row per entry of the global ``env.stocks``, in that order. Price
        and Value are pre-formatted dollar strings.
    """
    prices = obs['prices']
    holdings = obs['positions']
    rows = [
        {
            'Stock': name,
            'Price': f"${prices[name]:.2f}",
            'Holdings': holdings[name],
            'Value': f"${prices[name] * holdings[name]:.2f}",
        }
        for name in env.stocks
    ]
    return pd.DataFrame(rows)
def create_news_display(obs):
    """Render the upcoming day's news from an observation as an HTML card.

    Args:
        obs: Observation dict; reads ``news_next_day_text`` (headline) and
            ``news_next_day`` (list of variable changes, or None).

    Returns:
        HTML string. When there is a headline, this is a styled card with the
        headline and, if present, one line per variable change, labelled with
        the names in the global ``env.variables``. Otherwise it is a grey
        "No more news available" placeholder.
    """
    import html  # only this renderer needs escaping

    news_text = obs['news_next_day_text']
    if not news_text:
        return "<div style='padding: 20px; background: #f0f0f0; border-radius: 10px; text-align: center;'>๐ญ No more news available</div>"
    # Headlines and variable names come from user-supplied config JSON, so
    # escape them before embedding them in markup.
    news_html = f"""
    <div style='background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
                padding: 20px; border-radius: 10px; color: white; margin: 10px 0;'>
        <h3 style='margin-top: 0;'>๐ฐ Next Day News</h3>
        <p style='font-size: 16px; line-height: 1.6;'>{html.escape(str(news_text))}</p>
    """
    changes = obs['news_next_day']
    if changes:
        news_html += "<p style='font-size: 14px; margin-top: 10px;'><b>Variable Changes:</b><br/>"
        # zip stops at the shorter sequence, so a timeline entry with fewer
        # changes than declared variables no longer raises IndexError.
        for var, change in zip(env.variables, changes):
            news_html += f"โข {html.escape(str(var))}: <b>{'+' if change > 0 else ''}{change}</b><br/>"
        news_html += "</p>"
    news_html += "</div>"
    return news_html
| # def create_price_chart(): | |
| # if len(history) <= 1: | |
| # fig, ax = plt.subplots(figsize=(10, 6)) | |
| # ax.text(0.5, 0.5, 'Trade to see price history', | |
| # ha='center', va='center', fontsize=14, color='gray') | |
| # ax.axis('off') | |
| # return fig | |
| # df = pd.DataFrame(history) | |
| # fig, ax = plt.subplots(figsize=(10, 6)) | |
| # colors = ['#3b82f6', '#10b981', '#f59e0b', '#ef4444', '#8b5cf6'] | |
| # for i, stock in enumerate(env.stocks): | |
| # ax.plot(df['day'], df[stock], marker='o', linewidth=2, color=colors[i % len(colors)], label=stock) | |
| # ax.set_xlabel('Day') | |
| # ax.set_ylabel('Price ($)') | |
| # ax.set_title('Stock Price History') | |
| # ax.legend() | |
| # ax.grid(True, alpha=0.3) | |
| # return fig | |
| # def create_price_chart(): | |
| # """Create individual price chart for each stock""" | |
| # if len(history) <= 1: | |
| # fig, axs = plt.subplots(1, 1, figsize=(10, 6)) | |
| # axs.text(0.5, 0.5, 'Trade to see price history', | |
| # ha='center', va='center', fontsize=14, color='gray') | |
| # axs.axis('off') | |
| # return fig | |
| # df = pd.DataFrame(history) | |
| # num_stocks = len(env.stocks) | |
| # fig, axs = plt.subplots(num_stocks, 1, figsize=(10, 4*num_stocks), sharex=True) | |
| # # With a single stock, plt.subplots returns one Axes rather than an array, so wrap it in a list | |
| # if num_stocks == 1: | |
| # axs = [axs] | |
| # colors = ['#3b82f6', '#10b981', '#f59e0b', '#ef4444', '#8b5cf6'] | |
| # for i, stock in enumerate(env.stocks): | |
| # ax = axs[i] | |
| # ax.plot(df['day'], df[stock], marker='o', linewidth=2, color=colors[i % len(colors)], label=stock) | |
| # ax.set_ylabel(f'{stock} ($)') | |
| # ax.set_title(f'{stock} Price History') | |
| # ax.legend(loc='best', framealpha=0.8) | |
| # ax.grid(True, alpha=0.3) | |
| # axs[-1].set_xlabel('Day') | |
| # plt.tight_layout() | |
| # return fig | |
def create_price_chart():
    """Return a Plotly line chart of each stock's price over the recorded days.

    Reads the module-level ``history`` (one dict per day with a column per
    stock) and ``env.stocks``. When there is at most one snapshot there is no
    line to draw, so a blank figure with a "Trade to see price history" hint
    is returned instead.
    """
    if len(history) > 1:
        frame = pd.DataFrame(history)
        palette = ['#3b82f6', '#10b981', '#f59e0b', '#ef4444', '#8b5cf6']
        traces = [
            go.Scatter(
                x=frame['day'],
                y=frame[ticker],
                mode='lines+markers',
                name=ticker,
                # Cycle through the palette when there are more stocks than colours.
                line=dict(color=palette[pos % len(palette)], width=2),
                marker=dict(size=6),
            )
            for pos, ticker in enumerate(env.stocks)
        ]
        chart = go.Figure(data=traces)
        chart.update_layout(
            title="Stock Price History",
            xaxis_title="Day",
            yaxis_title="Price ($)",
            template="plotly_white",
            legend=dict(title="Stocks", orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
            # Grow the plot by 50px per stock.
            height=400 + 50 * len(env.stocks),
        )
        return chart

    # Not enough history yet: an empty canvas carrying a hint annotation.
    blank = go.Figure()
    blank.add_annotation(
        text="Trade to see price history",
        xref="paper",
        yref="paper",
        showarrow=False,
        font=dict(size=16, color="gray"),
    )
    blank.update_layout(
        xaxis=dict(visible=False),
        yaxis=dict(visible=False),
        template="plotly_white",
        height=400,
    )
    return blank
def create_value_chart():
    """Return a matplotlib chart of total portfolio value per recorded day.

    Reads the module-level ``history``. When there is at most one snapshot, a
    blank figure with a "Trade to see portfolio value" hint is returned.
    Otherwise the chart shows the value line, shading between the line and the
    starting value, and a dashed red reference line at the starting value.
    """
    # A standalone Figure instead of plt.subplots(): pyplot keeps every figure
    # it creates alive until plt.close(). This function runs on every trade
    # and every day, so pyplot figures would pile up for the server's lifetime.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(10, 6))
    ax = fig.subplots()
    if len(history) <= 1:
        ax.text(0.5, 0.5, 'Trade to see portfolio value', ha='center', va='center', fontsize=14, color='gray')
        ax.axis('off')
        return fig
    df = pd.DataFrame(history)
    initial_value = history[0]['total_value']
    ax.plot(df['day'], df['total_value'], marker='o', linewidth=3, color='#8b5cf6', label='Portfolio Value')
    # Shade gains/losses relative to the starting value. Filling down to 0
    # (fill_between's default baseline) forced the y-axis to include 0, which
    # flattened day-to-day changes that are small relative to the total.
    ax.fill_between(df['day'], df['total_value'], initial_value, alpha=0.2, color='#8b5cf6')
    ax.axhline(y=initial_value, color='red', linestyle='--', alpha=0.5, label=f'Initial: ${initial_value:.2f}')
    ax.set_xlabel('Day')
    ax.set_ylabel('Value ($)')
    ax.set_title('Portfolio Value History')
    ax.legend()
    ax.grid(True, alpha=0.3)
    return fig
def execute_trade(stock, action, amount):
    """Buy or sell shares at the current day's price without advancing time.

    Args:
        stock: Stock name; must be one of ``env.stocks``.
        action: "Buy" purchases; any other value sells.
        amount: Requested share count from the UI (truncated with ``int``).

    Returns:
        ``(status, portfolio_df, news_html, price_chart, value_chart)``. On a
        validation failure or unexpected error, the status carries the message
        and the other four items are None.

    Side effects:
        Mutates the global ``env``'s ``positions`` and ``cash``. A buy must be
        fully covered by cash. A sell is capped at the shares held.
    """
    global env, history
    if env is None:
        return "โ Please initialize the environment first", None, None, None, None
    try:
        amount = int(amount)
        if amount <= 0:
            return "โ Amount must be positive", None, None, None, None
        if stock not in env.stocks:
            # E.g. the selector still shows tickers from a different config.
            return f"โ Unknown stock {stock!r}. Available: {', '.join(map(str, env.stocks))}", None, None, None, None
        idx = env.stocks.index(stock)
        price = env.prices[idx]
        if action == "Buy":
            cost = price * amount
            if cost > env.cash:
                return f"โ Insufficient cash!\nNeed: ${cost:.2f}\nHave: ${env.cash:.2f}", None, None, None, None
            env.positions[idx] += amount
            env.cash -= cost
            status = f"โ Bought {amount} {stock} at ${price:.2f}"
        else:
            # positions is a float array; cast so the message says "Sold 3", not "Sold 3.0".
            qty = min(amount, int(env.positions[idx]))
            if qty <= 0:
                return "โ No shares to sell", None, None, None, None
            env.positions[idx] -= qty
            env.cash += price * qty
            status = f"โ Sold {qty} {stock} at ${price:.2f}"
        obs = env._get_observation()
        status += f"\n๐ฐ Cash: ${obs['cash']:.2f} | Total Value: ${obs['total_value']:.2f}"
        return status, create_portfolio_display(obs), create_news_display(obs), create_price_chart(), create_value_chart()
    except Exception as e:
        # UI boundary: report instead of crashing the callback.
        return f"โ Error: {str(e)}", None, None, None, None
def advance_day():
    """Advance the session by one day with no trades and refresh the UI.

    Returns:
        ``(status, portfolio_df, news_html, price_chart, value_chart)``. When
        the final day is reached, the status reports the final value and the
        profit relative to the day-0 value. If the episode had already ended,
        nothing is stepped and the current state is redisplayed.

    Side effects:
        Steps the global ``env`` and appends the new day's snapshot to
        ``history``.
    """
    global env, history
    if env is None:
        return "โ Please initialize the environment first", None, None, None, None
    if env.t >= env.num_days:
        # Stepping again would push the day counter past num_days and append
        # meaningless snapshots to the charts.
        obs = env._get_observation()
        return ("โ Episode already finished - press Reset to start over",
                create_portfolio_display(obs), create_news_display(obs),
                create_price_chart(), create_value_chart())
    obs, _reward, done, _info = env.step({"buy": {}, "sell": {}})
    history.append({'day': obs['day'], 'total_value': obs['total_value'], **obs['prices']})
    if done:
        init_val = history[0]['total_value']
        profit = obs['total_value'] - init_val
        # Avoid ZeroDivisionError when the session started with zero value.
        pct = profit / init_val * 100 if init_val else 0.0
        status = f"๐ Finished! Final value: ${obs['total_value']:.2f}\nProfit: {profit:+.2f} ({pct:+.2f}%)"
    else:
        status = f"โ Advanced to Day {obs['day']} | ๐ฐ Cash: ${obs['cash']:.2f} | ๐ Value: ${obs['total_value']:.2f}"
    return status, create_portfolio_display(obs), create_news_display(obs), create_price_chart(), create_value_chart()
def reset_env():
    """Restart the current session from day 0 (same config) and refresh the UI.

    If no session exists yet, this falls back to ``initialize_env()`` with the
    default config.

    Returns:
        ``(status, portfolio_df, news_html, price_chart, value_chart)``.

    Side effects:
        Resets the global ``env`` and replaces ``history`` with the day-0
        snapshot.
    """
    global env, history
    if env is None:
        return initialize_env()
    obs = env.reset()
    history = [{'day': obs['day'], 'total_value': obs['total_value'], **obs['prices']}]
    status = f"๐ Environment Reset! Day {obs['day']}, Cash ${obs['cash']:.2f}"
    # Return the placeholder charts, as initialize_env does, rather than None,
    # so the plot panels show the "Trade to see ..." hint instead of going blank.
    return status, create_portfolio_display(obs), create_news_display(obs), create_price_chart(), create_value_chart()
# ======== UI section ========
custom_css = """
.gradio-container { font-family: 'Arial', sans-serif; }
"""


def _sync_stock_dropdown():
    """Point the stock selector at the active environment's stock names.

    The selector is created with placeholder choices S0..S5, which do not
    match every config (DEFAULT_CONFIG, loaded on page load, uses
    TECH/ENERGY/FINANCE). It is therefore refreshed after every
    (re)initialisation. The selector is left unchanged when there is no
    environment or it has no stocks.
    """
    if env is None or not env.stocks:
        return gr.update()
    return gr.update(choices=list(env.stocks), value=env.stocks[0])


# Scan the config directory once; the result feeds both the choices and the default.
_config_files = list_config_files()

with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, title="AI Trading Arena") as demo:
    gr.Markdown("# ๐ AI Trading Arena\n### Interactive Stock Trading Simulator")
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## ๐ฎ Control Panel")
            with gr.Accordion("๐ Configuration", open=False):
                gr.Markdown("Select a config file from `/config` folder or paste custom JSON")
                config_file_dropdown = gr.Dropdown(
                    choices=_config_files,
                    label="Choose Config File",
                    value=_config_files[0] if _config_files else None
                )
                load_file_btn = gr.Button("๐ Load from File", variant="secondary")
                config_input = gr.Textbox(
                    label="Custom Config JSON",
                    placeholder='{"num_days": 30, "stocks": ["TECH", "ENERGY"], ...}',
                    lines=4
                )
                init_btn = gr.Button("๐ Load Config", variant="primary", size="lg")
            with gr.Accordion("๐น Trading Operations", open=True):
                stock_dropdown = gr.Dropdown(
                    choices=["S0", "S1", "S2", "S3", "S4", "S5"],
                    label="Select Stock",
                    value="S0"
                )
                action_radio = gr.Radio(choices=["Buy", "Sell"], label="Action", value="Buy")
                amount_input = gr.Number(label="Amount (shares)", value=10, minimum=1, step=1)
                trade_btn = gr.Button("๐ Execute Trade", variant="primary", size="lg")
            with gr.Row():
                advance_btn = gr.Button("โญ๏ธ Next Day", variant="primary", size="lg")
                reset_btn = gr.Button("๐ Reset", variant="secondary", size="lg")
            status_output = gr.Textbox(label="๐ Status & Messages", lines=8, interactive=False)
        with gr.Column(scale=2):
            gr.Markdown("## ๐ Market Dashboard")
            portfolio_table = gr.Dataframe(label="๐ผ Portfolio Holdings", interactive=False)
            news_display = gr.HTML(label="๐ฐ Market News")
            with gr.Tab("๐ Price History"):
                price_chart = gr.Plot(label="Stock Prices Over Time")
            with gr.Tab("๐ฐ Portfolio Value"):
                value_chart = gr.Plot(label="Total Portfolio Value")

    # Event wiring. Every session-changing handler fills the same five outputs.
    ui_outputs = [status_output, portfolio_table, news_display, price_chart, value_chart]
    load_file_btn.click(fn=load_config_from_file, inputs=[config_file_dropdown], outputs=[config_input])
    # After (re)initialising, resync the stock selector with the new env's stocks.
    init_btn.click(fn=initialize_env, inputs=[config_input], outputs=ui_outputs).then(
        fn=_sync_stock_dropdown, inputs=[], outputs=[stock_dropdown])
    reset_btn.click(fn=reset_env, inputs=[], outputs=ui_outputs).then(
        fn=_sync_stock_dropdown, inputs=[], outputs=[stock_dropdown])
    trade_btn.click(fn=execute_trade, inputs=[stock_dropdown, action_radio, amount_input],
                    outputs=ui_outputs)
    advance_btn.click(fn=advance_day, inputs=[], outputs=ui_outputs)
    demo.load(fn=initialize_env, inputs=[], outputs=ui_outputs).then(
        fn=_sync_stock_dropdown, inputs=[], outputs=[stock_dropdown])
# Start the Gradio server only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    demo.launch()