OdysseyArena / GUI_Trade_Task.py
beatccjiang's picture
Upload project files to Hugging Face Spaces
907121e verified
# ==================== Trade 任务模块 ====================
"""
Trade 任务相关的所有函数和界面组件
支持多用户并发:使用 gr.State 管理每个用户会话的状态
使用统一进度管理模块存储数据
"""
import json
import os
import numpy as np
from typing import List, Tuple, Optional, Dict, Any
import gradio as gr
# 导入统一进度管理模块
import progress_manager
# Import the Trade environment.
import sys
# Absolute directory of this module, so the bundled TradeEnv directory is
# located relative to this file rather than the process working directory.
current_dir = os.path.dirname(os.path.abspath(__file__))
tradeenv_path = os.path.join(current_dir, "TradeEnv")
if os.path.exists(tradeenv_path):
    # Prepend so that the import below resolves TradeEnv_v2 from the bundled
    # directory first. If the directory is absent, TradeEnv_v2 must already be
    # importable from elsewhere on sys.path (NOTE(review): confirm deployment).
    sys.path.insert(0, tradeenv_path)
from TradeEnv_v2 import TradeArenaEnv_Deterministic
# ------------------- Constants -------------------
# Maximum number of recorded actions (history entries) allowed per Trade environment.
TRADE_MAX_STEPS: int = 120
# ------------------- Example Text -------------------
TRADE_EXAMPLE_TEXT = """
## 📖 Trading Environment Usage Instructions
### Scenario Description
You are a stock trader who needs to perform buy and sell operations across multiple trading days to maximize returns within 120 days.
### Important Concepts
- **S0, S1**: Stock codes (Stocks), representing 2 different stocks that can be bought and sold
- **F0, F1**: Market factors (Factors), representing market factors that affect stock prices
- News will report changes in these factors (e.g., "F0 rose slightly (+0.03)")
- Factor changes affect stock prices through a dependency matrix
- You need to predict stock price changes based on news and then trade
- Check news, for example "F0 rose slightly (+0.03) | F1 decreased significantly (-0.10)" to predict which stocks will rise/fall based on factor changes
- Buying is limited by cash
- Selling is limited by holdings
### Available Operations
- **Buy Stock**: Input positive number to buy (e.g., S0 input 100 means buy 100 shares of S0)
- **Sell Stock**: Input negative number to sell (e.g., S0 input -50 means sell 50 shares of S0)
- Buying is limited by cash, selling is limited by holdings
## Example
### Example Logic (Only shown in examples. In actual tasks, these rules are hidden and need to be inferred by users)
- The matrix corresponding to S0, S1, F0, F1 is [[0.1, 0.2], [-0.3, 0.4]]
- This means: if F0 rises by 1 point, S0 rises by 0.1 points; if F0 rises by 1 point, S1 falls by 0.3 points; if F1 rises by 1 point, S0 rises by 0.2 points; if F1 rises by 1 point, S1 rises by 0.4 points
### Initial Environment in This Example
- You have 100 cash
- S0 initial price is 1, S1 initial price is 2
- This example is a simple demonstration with only 3 trading days (actual task is 120 days)
### Example Steps
**Note: You need to discover the rules between stocks S and factors F yourself. The example below is from a god's-eye view to demonstrate how to use the rules**
1. **Step 1 (Day 1)**:
- Environment state before execution: Tomorrow F0 rose significantly (+0.10) | F1 rose slightly (+0.05)
- Stock prices before execution: S0 1.00, S1 2.00, Cash 100
- Action: Buy 100 shares of S0
- Reason: S0 tomorrow's price = 1.00 + (0.1×0.10) + (0.2×0.05) = 1.00 + 0.01 + 0.01 = 1.02 (up 2%), while S1 tomorrow's price is S1 = 2.00 + ((-0.3)×0.10) + (0.4×0.05) = 2.00 - 0.03 + 0.02 = 1.99 (down 0.5%). S0 rises while S1 falls, so buy S0. Buying 100 shares of S0 costs 100, cash becomes 0.
2. **Step 2 (Day 2)**:
- Environment state before execution: Tomorrow F0 decreased significantly (-0.15) | F1 rose significantly (+0.10)
- Stock prices before execution: S0 1.02, S1 1.99, Cash 0, Holdings 100 shares of S0
- Action: Sell 100 shares of S0, buy approximately 51 shares of S1
- Reason: S0 tomorrow's price = 1.02 + (0.1×(-0.15)) + (0.2×0.10) = 1.02 - 0.015 + 0.02 = 1.025 (slight rise 0.5%), while S1 tomorrow's price is S1 = 1.99 + ((-0.3)×(-0.15)) + (0.4×0.10) = 1.99 + 0.045 + 0.04 = 2.075 (up 4.3%). S1's rise is much greater than S0, so sell S0 and buy S1. Selling 100 shares of S0 yields 102, can buy approximately 51 shares of S1 (102/1.99≈51.26, rounded to 51 shares, cost about 101.49).
3. **Step 3 (Day 3)**:
- Environment state before execution: Tomorrow F0 stable (0.00) | F1 rose significantly (+0.20)
- Stock prices before execution: S0 1.025, S1 2.075, Cash 0.51, Holdings 51 shares of S1
- Action: No operation (or use remaining cash to buy a small amount of S1)
- Reason: S0 tomorrow's price = 1.025 + (0.1×0) + (0.2×0.20) = 1.025 + 0.04 = 1.065 (up 3.9%), while S1 tomorrow's price is S1 = 2.075 + ((-0.3)×0) + (0.4×0.20) = 2.075 + 0.08 = 2.155 (up 3.9%). Both stocks have similar percentage gains, but S1 has a larger absolute gain (0.08 vs 0.04), and we already hold S1, so maintain position.
### Final State: 51 shares of S1, price 2.155 per share, total value approximately 109.91 (51×2.155), plus remaining cash approximately 0.51, total value approximately 110.42, return rate approximately 10.42%
"""
# ------------------- 状态管理 -------------------
def create_trade_state() -> Dict[str, Any]:
    """Create a fresh Trade task state for one user session.

    Every call returns new container objects, so concurrent sessions never
    share mutable state.

    Returns:
        A dict with keys ``env`` (TradeArenaEnv_Deterministic instance or
        None), ``test_data`` (loaded configs), ``current_env_idx`` (0-based
        index) and ``history_records`` (formatted step records).
    """
    session: Dict[str, Any] = dict(
        env=None,
        test_data=[],
        current_env_idx=0,
        history_records=[],
    )
    return session
# ------------------- 工具函数 -------------------
def format_trade_state(obs: Dict[str, Any]) -> str:
    """Render a Trade observation as human-readable multi-line text.

    Args:
        obs: Observation dict. Reads ``day``, ``cash``, ``total_value``,
            ``prices`` (stock name -> price), ``positions`` (stock name ->
            held shares) and ``news_next_day_text``. Missing keys fall back
            to 0 or empty.

    Returns:
        Newline-joined summary. The per-stock section appears only when
        ``prices`` is non-empty, and the news line only when
        ``news_next_day_text`` is truthy.
    """
    summary = [
        f"Trading day: {obs.get('day', 0)}",
        f"Cash: {obs.get('cash', 0):.2f}",
        f"Total value: {obs.get('total_value', 0):.2f}",
    ]
    price_by_stock = obs.get('prices', {})
    shares_by_stock = obs.get('positions', {})
    if price_by_stock:
        summary.append("\nStock prices:")
        for ticker, quote in price_by_stock.items():
            held = shares_by_stock.get(ticker, 0)
            summary.append(
                f" {ticker}: {quote:.2f} (Holdings: {held}, Total value: {held * quote:.2f})"
            )
    headline = obs.get('news_next_day_text')
    if headline:
        summary.append(f"\nNext day news: {headline}")
    return "\n".join(summary)
def format_trade_history_record(step_num: int, obs_before: Dict[str, Any], action_str: str, reward: float, total_value: float, error: str = None) -> str:
    """Format one step of the action history.

    Args:
        step_num: 1-based step number.
        obs_before: Observation taken before the action. Provides ``day``,
            ``prices`` and ``news_next_day_text``.
        action_str: The action as entered (or a descriptive label).
        reward: Reward returned by the step (ignored when *error* is set).
        total_value: Portfolio value after the step (ignored when *error* is set).
        error: Error message. When truthy, the action is marked invalid and
            the error replaces the reward feedback.

    Returns:
        Newline-joined record: header, prices sorted by stock name (if any),
        news line (``None`` when absent), action line and feedback line.
    """
    parts = [f"Step {step_num} (Day {obs_before.get('day', 0)}):"]

    day_prices = obs_before.get('prices', {})
    if day_prices:
        parts.append("Current day stock prices:")
        parts += [f" {name}: {px:.2f}" for name, px in sorted(day_prices.items())]

    news_text = obs_before.get('news_next_day_text')
    parts.append(f"Next day news: {news_text if news_text else 'None'}")

    if error:
        parts += [f"Action: {action_str} (invalid)", f"Feedback: ❌ {error}"]
    else:
        parts += [
            f"Action: {action_str}",
            f"Feedback: Reward={reward:.2f}, Total Value={total_value:.2f}",
        ]
    return "\n".join(parts)
def load_trade_test_data(state: Dict[str, Any], current_dir: str, max_files: int = 30) -> Tuple[Dict[str, Any], str]:
    """Load the Trade test configurations into ``state['test_data']``.

    For each ``i`` in ``1..max_files``, the file
    ``test_data/trade/test_trade_config_{i}.json`` is looked up first under
    *current_dir* and then relative to the process working directory. Files
    found in neither place are skipped, so fewer than *max_files* configs may
    be loaded.

    Args:
        state: Per-session state dict. ``state['test_data']`` is replaced
            when loading finishes without an exception.
        current_dir: Base directory searched first.
        max_files: Highest config index to try. Defaults to 30, the count
            the original implementation assumed.

    Returns:
        ``(state, message)``, where *message* is a user-facing status string.
        If no file was found, ``state['test_data']`` is set to ``[]`` and a
        warning is returned instead of a success message.
    """
    try:
        test_data = []
        for i in range(1, max_files + 1):
            relative_file = f"test_data/trade/test_trade_config_{i}.json"
            test_file = os.path.join(current_dir, relative_file)
            if not os.path.exists(test_file):
                # Fall back to a path relative to the working directory.
                test_file = relative_file
            if not os.path.exists(test_file):
                continue
            with open(test_file, 'r', encoding='utf-8') as f:
                test_data.append(json.load(f))
        state['test_data'] = test_data
    except FileNotFoundError as e:
        return state, f"❌ File not found: {str(e)}"
    except Exception as e:
        return state, f"❌ Load failed: {str(e)}"
    if not test_data:
        # Reporting "Successfully loaded 0" would hide a missing data directory.
        return state, f"⚠️ No test environments found (looked for test_data/trade/test_trade_config_1..{max_files}.json)"
    return state, f"✅ Successfully loaded {len(test_data)} test environments"
def trade_save_progress_internal(state: Dict[str, Any], current_user_id: str, save_dir: str) -> str:
"""保存 Trade 环境进度(使用统一进度管理模块)"""
# Auto-generate user ID if not provided
if not current_user_id:
import uuid
current_user_id = f"user_{uuid.uuid4().hex[:8]}"
env = state.get('env')
if env is None:
return "⚠️ No progress to save"
try:
current_env_idx = state.get('current_env_idx', 0)
history_records = state.get('history_records', [])
test_data = state.get('test_data', [])
env_progress = {
"user_id": current_user_id,
"env_idx": current_env_idx,
"env_idx_display": current_env_idx + 1,
# 不再保存 config,因为可以从 test_data[env_idx] 获取
"day": env.t,
"cash": float(env.cash),
"positions": env.positions.tolist() if hasattr(env.positions, 'tolist') else list(env.positions),
"prices": env.prices.tolist() if hasattr(env.prices, 'tolist') else list(env.prices),
"variables_state": env.variables_state.tolist() if hasattr(env.variables_state, 'tolist') else list(env.variables_state),
"history": history_records,
"num_steps": len(history_records),
"done": env.t >= env.num_days,
"success": env.t >= env.num_days,
}
result = progress_manager.save_task_environment_progress(
current_user_id, save_dir, "trade", current_env_idx, env_progress
)
return f"✅ Progress saved (Environment {current_env_idx + 1}, Steps {len(history_records)})"
except Exception as e:
return f"❌ Save failed: {str(e)}"
def get_trade_stock_input_updates(env) -> List[Dict[str, Any]]:
    """Build visibility/label updates for the 10 stock input boxes.

    Args:
        env: A TradeArenaEnv_Deterministic instance, or None. When it is None
            or has no ``stocks`` attribute, every box is hidden.

    Returns:
        Exactly 10 ``gr.update(...)`` results. Slot ``i`` is shown and
        labelled ``env.stocks[i]`` when such a stock exists; otherwise it is
        hidden. Stocks beyond the 10th get no input box.
    """
    MAX_STOCKS = 10
    if env is None or not hasattr(env, 'stocks'):
        return [gr.update(visible=False) for _ in range(MAX_STOCKS)]
    names = env.stocks  # actual stock names reported by the environment
    return [
        gr.update(visible=True, label=names[slot]) if slot < len(names) else gr.update(visible=False)
        for slot in range(MAX_STOCKS)
    ]
def _restore_trade_env(config: Dict[str, Any], saved: Dict[str, Any]):
    """Build a TradeArenaEnv_Deterministic from *config* and overlay the saved progress fields."""
    env = TradeArenaEnv_Deterministic(config)
    env.t = saved.get("day", 0)
    env.cash = saved.get("cash", env.initial_cash)
    # Saved arrays are plain lists; turn them back into numpy arrays. A missing
    # field keeps the freshly constructed environment's value.
    for attr in ("positions", "prices", "variables_state"):
        current = getattr(env, attr)
        fallback = current.tolist() if hasattr(current, 'tolist') else list(current)
        setattr(env, attr, np.array(saved.get(attr, fallback)))
    # Restore the news for the next day from the config timeline (None if absent).
    env.next_day_news = config.get("timeline", {}).get(f"day_{env.t + 1}")
    return env


def trade_load_environment(state: Dict[str, Any], env_idx_display: int, current_user_id: str, save_dir: str) -> Tuple[Dict[str, Any], str, str, str, str, str, str]:
    """Load a Trade environment, resuming saved progress when it exists.

    Args:
        state: Session state. ``test_data`` must already be loaded.
            ``env``, ``current_env_idx`` and ``history_records`` are replaced.
        env_idx_display: 1-based environment number (coerced to int).
        current_user_id: User ID. A random one is generated when empty.
        save_dir: Directory passed through to ``progress_manager``.

    Returns:
        ``(state, info, state_display, logic, history_display, progress,
        steps_info)``. ``logic`` is always an empty string. A brand-new
        environment is saved immediately. Stock input-box updates are left to
        the main interface.
    """
    # Auto-generate user ID if not provided
    if not current_user_id:
        import uuid
        current_user_id = f"user_{uuid.uuid4().hex[:8]}"
    empty_steps = f"0 / {TRADE_MAX_STEPS}"
    progress_hint = "Click 'View Unfinished Problems' button to view progress"
    test_data = state.get('test_data', [])
    if not test_data:
        return state, "❌ Please load test data first", "", "", "", "Click 'View Uncompleted Problems' button to view progress", empty_steps
    # gr.Number may deliver a float; list indexing and progress keys need an int.
    env_idx = int(env_idx_display) - 1
    if env_idx < 0 or env_idx >= len(test_data):
        return state, f"❌ Environment index out of range (1-{len(test_data)})", "", "", "", progress_hint, empty_steps

    saved_progress_data = progress_manager.get_task_environment_progress(
        current_user_id, save_dir, "trade", env_idx
    )
    if saved_progress_data:
        # Older saves may embed the config; newer ones rely on test_data[env_idx].
        config = saved_progress_data.get("config") or test_data[env_idx]
        if config:
            state['current_env_idx'] = env_idx
            # Copy so later appends do not mutate the record returned by progress_manager.
            history_records = list(saved_progress_data.get("history", []))
            state['history_records'] = history_records
            state['env'] = _restore_trade_env(config, saved_progress_data)
            state_display = format_trade_state(state['env']._get_observation())
            history_display = "\n\n".join(history_records) if history_records else "No history records"
            info = f"✅ Environment {env_idx_display}/{len(test_data)} loaded\n"
            info += f"Steps: {len(history_records)}"
            steps_info = f"{len(history_records)} / {TRADE_MAX_STEPS}"
            return state, info, state_display, "", history_display, progress_hint, steps_info

    # No usable saved progress: start a fresh environment and persist it right away.
    state['current_env_idx'] = env_idx
    state['env'] = TradeArenaEnv_Deterministic(test_data[env_idx])
    state['history_records'] = []
    trade_save_progress_internal(state, current_user_id, save_dir)
    state_display = format_trade_state(state['env']._get_observation())
    info = f"✅ Environment {env_idx_display}/{len(test_data)} initialized (new environment)\n"
    return state, info, state_display, "", "Environment initialized (new environment)\n", progress_hint, empty_steps
def _collect_trade_orders(stock_inputs: Optional[dict], valid_stocks) -> Tuple[Dict[str, int], Dict[str, int]]:
    """Split raw input-box values into ``(buy_dict, sell_dict)`` keyed by stock name.

    Positive values become buy quantities and negative values become sell
    quantities (absolute value), truncated with ``int``. Zero, ``None`` and
    names not in *valid_stocks* are ignored.
    """
    buy_dict: Dict[str, int] = {}
    sell_dict: Dict[str, int] = {}
    for stock, value in (stock_inputs or {}).items():
        if stock not in valid_stocks or value is None:
            continue
        if value > 0:
            buy_dict[stock] = int(value)
        elif value < 0:
            sell_dict[stock] = int(abs(value))
    return buy_dict, sell_dict


def _trade_skip_step(state: Dict[str, Any], env, current_user_id: str, save_dir: str) -> Tuple[Dict[str, Any], str, str, str, bool, str]:
    """Advance *env* by one day with an empty action and record it as a skip."""
    history_records = state.get('history_records', [])
    current_steps = len(history_records)
    if current_steps >= TRADE_MAX_STEPS:
        current_state_display = format_trade_state(env._get_observation())
        trade_save_progress_internal(state, current_user_id, save_dir)
        feedback_info = f"⚠️ Reached step limit ({TRADE_MAX_STEPS} steps)\n"
        feedback_info += "Task ended (failed to complete within the specified number of steps)\n"
        feedback_info += "Cannot continue executing actions\n"
        return (state, feedback_info, current_state_display, "\n\n".join(history_records),
                True, f"{current_steps} / {TRADE_MAX_STEPS}")
    action = {}  # no buy/sell: time still advances by one day
    try:
        obs_before = env._get_observation()
        obs, reward, done, _info = env.step(action)
        state_display = format_trade_state(obs)
        total_value = obs.get('total_value', 0)
        history_records.append(format_trade_history_record(
            current_steps + 1, obs_before, "Skip (no buy/sell operations)", reward, total_value
        ))
        state['history_records'] = history_records
        if len(history_records) >= TRADE_MAX_STEPS:
            done = True
        trade_save_progress_internal(state, current_user_id, save_dir)
        feedback_info = f"Action: No operation (skip)\nFeedback: Reward={reward:.2f}, Total Value={total_value:.2f}\n"
        if done:
            if env.t >= env.num_days:
                feedback_info += "🎉 Task completed! All trading days ended!\n"
            else:
                feedback_info += f"⚠️ Task ended (reached step limit {TRADE_MAX_STEPS} steps)\n"
        return (state, feedback_info, state_display, "\n\n".join(history_records),
                done, f"{len(history_records)} / {TRADE_MAX_STEPS}")
    except Exception as e:
        current_state_display = format_trade_state(env._get_observation())
        return (state,
                f"⚠️ No operation (all inputs are 0), but error occurred during execution: {str(e)}",
                current_state_display, "\n\n".join(history_records),
                False, f"{len(history_records)} / {TRADE_MAX_STEPS}")


def trade_step_environment_from_inputs(state: Dict[str, Any], stock_inputs: dict, current_user_id: str, save_dir: str) -> Tuple[Dict[str, Any], str, str, str, bool, str]:
    """Execute one Trade step from the per-stock input boxes.

    Args:
        state: Session state.
        stock_inputs: ``{stock_name: value}``. Positive values buy and
            negative values sell. Names must be actual environment stock names
            (e.g. "S0", "S1"); other names are ignored.
        current_user_id: User ID. A random one is generated when empty.
        save_dir: Directory passed through to ``progress_manager``.

    Returns:
        ``(state, feedback, state_display, history_display, done, steps_info)``.
        When every input is zero/empty, the day is skipped (time advances
        with no trades) instead of being rejected.
    """
    env = state.get('env')
    valid_stocks = env.stocks if env is not None else []
    buy_dict, sell_dict = _collect_trade_orders(stock_inputs, valid_stocks)

    if buy_dict or sell_dict:
        action = {}
        if buy_dict:
            action["buy"] = buy_dict
        if sell_dict:
            action["sell"] = sell_dict
        return trade_step_environment(state, json.dumps(action, ensure_ascii=False), current_user_id, save_dir)

    if env is None:
        return state, "❌ Please initialize environment first", "Please initialize environment first", "", False, f"0 / {TRADE_MAX_STEPS}"
    # Proceed with a generated ID rather than rejecting the skip, matching
    # trade_step_environment's handling of a missing user ID.
    if not current_user_id:
        import uuid
        current_user_id = f"user_{uuid.uuid4().hex[:8]}"
    return _trade_skip_step(state, env, current_user_id, save_dir)
def _record_failed_trade_step(state: Dict[str, Any], history_records: List[str], obs_before: Dict[str, Any], action_str: str, error: str, current_user_id: str, save_dir: str, state_display: str) -> Tuple[Dict[str, Any], str, str, str, bool, str]:
    """Record an invalid action, persist progress and build the step return tuple.

    Only the failed attempt itself is appended, so the step counter can never
    exceed TRADE_MAX_STEPS.
    """
    history_records.append(format_trade_history_record(
        len(history_records) + 1, obs_before, action_str, 0, 0, error
    ))
    state['history_records'] = history_records
    done = len(history_records) >= TRADE_MAX_STEPS
    feedback_info = f"Action: {action_str}\nFeedback: ❌ {error}\n"
    if done:
        feedback_info += f"⚠️ Reached step limit ({TRADE_MAX_STEPS} steps)\n"
        feedback_info += "Task ended (failed to complete within the specified number of steps)\n"
    trade_save_progress_internal(state, current_user_id, save_dir)
    return (state, feedback_info, state_display, "\n\n".join(history_records),
            done, f"{len(history_records)} / {TRADE_MAX_STEPS}")


def trade_step_environment(state: Dict[str, Any], action_str: str, current_user_id: str, save_dir: str) -> Tuple[Dict[str, Any], str, str, str, bool, str]:
    """Execute one Trade step from a JSON action string.

    Args:
        state: Session state. Must contain an initialized ``env``.
        action_str: JSON action, e.g. ``{"buy": {"S0": 10}, "sell": {"S1": 5}}``.
        current_user_id: User ID. A random one is generated when empty.
        save_dir: Directory passed through to ``progress_manager``.

    Returns:
        ``(state, feedback, state_display, history_display, done, steps_info)``.
        Invalid JSON and errors raised by ``env.step`` are recorded as failed
        steps (they count toward TRADE_MAX_STEPS). Once the limit is reached,
        no further steps are recorded.
    """
    env = state.get('env')
    if env is None:
        return state, "❌ Please initialize environment first", "Please initialize environment first", "", False, f"0 / {TRADE_MAX_STEPS}"
    # Auto-generate user ID if not provided
    if not current_user_id:
        import uuid
        current_user_id = f"user_{uuid.uuid4().hex[:8]}"
    history_records = state.get('history_records', [])
    obs_before = env._get_observation()
    current_state_display = format_trade_state(obs_before)

    # Enforce the step limit before parsing, so invalid input cannot add
    # records past the limit.
    if len(history_records) >= TRADE_MAX_STEPS:
        trade_save_progress_internal(state, current_user_id, save_dir)
        feedback_info = f"⚠️ Reached step limit ({TRADE_MAX_STEPS} steps)\n"
        feedback_info += "Task ended (failed to complete within the specified number of steps)\n"
        feedback_info += "Cannot continue executing actions\n"
        return (state, feedback_info, current_state_display, "\n\n".join(history_records),
                True, f"{len(history_records)} / {TRADE_MAX_STEPS}")

    try:
        action = json.loads((action_str or "").strip())
    except json.JSONDecodeError:
        return _record_failed_trade_step(state, history_records, obs_before, action_str,
                                         "JSON format error", current_user_id, save_dir,
                                         current_state_display)

    try:
        obs, reward, done, _info = env.step(action)
        state_display = format_trade_state(obs)
        total_value = obs.get('total_value', 0)
        history_records.append(format_trade_history_record(
            len(history_records) + 1, obs_before, action_str, reward, total_value
        ))
        state['history_records'] = history_records
        feedback_info = f"Action: {action_str}\nFeedback: Reward={reward:.2f}, Total Value={total_value:.2f}\n"
        if len(history_records) >= TRADE_MAX_STEPS:
            done = True
            if env.t >= env.num_days:
                feedback_info += "🎉 Task completed! All trading days ended!\n"
            else:
                feedback_info += f"⚠️ Reached step limit ({TRADE_MAX_STEPS} steps), task ended (failed to complete all trading days within the specified number of steps)\n"
        elif done:
            feedback_info += "🎉 Task completed! All trading days ended!\n"
        trade_save_progress_internal(state, current_user_id, save_dir)
        return (state, feedback_info, state_display, "\n\n".join(history_records),
                done, f"{len(history_records)} / {TRADE_MAX_STEPS}")
    except Exception as e:
        return _record_failed_trade_step(state, history_records, obs_before, action_str,
                                         str(e), current_user_id, save_dir,
                                         current_state_display)
def trade_reset_environment(state: Dict[str, Any], current_user_id: str, save_dir: str) -> Tuple[Dict[str, Any], str, str, str, str, str]:
    """Reset the current Trade environment and clear its action history.

    Args:
        state: Session state containing ``env``.
        current_user_id: User ID passed to the progress save.
        save_dir: Directory passed through to ``progress_manager``.

    Returns:
        ``(state, info, state_display, history_display, progress, steps_info)``.
        The step counter always restarts at ``0 / TRADE_MAX_STEPS``.
    """
    steps_info = f"0 / {TRADE_MAX_STEPS}"
    env = state.get('env')
    if env is None:
        return state, "❌ Please initialize environment first", "", "", "Click 'View Uncompleted Problems' button to view progress", steps_info
    env.reset()
    state['history_records'] = []
    # Save immediately so that reloading this environment does not restore
    # the pre-reset progress.
    trade_save_progress_internal(state, current_user_id, save_dir)
    state_display = format_trade_state(env._get_observation())
    return state, "✅ Environment reset", state_display, "Environment reset\n", "Click 'View Unfinished Problems' button to view progress", steps_info
def get_trade_current_env_idx(state: Dict[str, Any]) -> int:
    """Return the 0-based index of the current Trade environment (0 when unset)."""
    return state['current_env_idx'] if 'current_env_idx' in state else 0
def get_trade_test_data(state: Dict[str, Any]) -> List[dict]:
    """Return the loaded Trade test configs, or a new empty list when none are loaded."""
    return state['test_data'] if 'test_data' in state else []
def get_trade_history_records(state: Dict[str, Any]) -> List[str]:
    """Return the formatted Trade step records, or a new empty list when there are none."""
    return state['history_records'] if 'history_records' in state else []
def get_trade_env(state: Dict[str, Any]):
    """Return the session's Trade environment instance, or None if none is loaded."""
    return state.get('env')
def get_trade_progress_summary(state: Dict[str, Any], user_id: str, save_dir: str) -> str:
    """Summarize a user's Trade progress via ``progress_manager``.

    An environment counts as completed when its saved record has ``success``
    or ``done`` set, or has used at least TRADE_MAX_STEPS steps. Records whose
    ``env_idx`` is missing or outside the loaded test set are ignored, so the
    completed count never exceeds the total.

    Args:
        state: Session state. ``test_data`` determines the total.
        user_id: User ID. A random one is generated when empty or blank.
        save_dir: Directory passed through to ``progress_manager``.

    Returns:
        A formatted multi-line summary, or a warning when no test data is loaded.
    """
    # Auto-generate user ID if not provided
    if not user_id or not user_id.strip():
        import uuid
        user_id = f"user_{uuid.uuid4().hex[:8]}"
    user_id = user_id.strip()

    test_data = state.get('test_data', [])
    total_envs = len(test_data) if test_data else 0
    # Check before touching stored progress: without test data there is nothing to summarize.
    if total_envs == 0:
        return "⚠️ Please load test data first"

    task_data = progress_manager.load_task_progress(user_id, save_dir, "trade")
    environments = task_data.get("environments", {})
    all_env_indices = set(range(total_envs))

    completed_envs = set()
    for progress_data in environments.values():
        env_idx = progress_data.get("env_idx", -1)
        finished = (
            progress_data.get("success", False)
            or progress_data.get("done", False)
            or progress_data.get("num_steps", 0) >= TRADE_MAX_STEPS
        )
        if finished and env_idx in all_env_indices:
            completed_envs.add(env_idx)

    incomplete_envs = sorted(all_env_indices - completed_envs)

    summary_lines = [
        f"📊 Trade Task - Progress Summary for User {user_id}",
        f"Total environments: {total_envs}",
        f"Completed: {len(completed_envs)}/{total_envs}",
        f"Incomplete: {len(incomplete_envs)}/{total_envs}",
    ]
    if incomplete_envs:
        summary_lines.append("\n❌ Incomplete environments:")
        # Show 1-based environment numbers, five per line.
        for i in range(0, len(incomplete_envs), 5):
            env_display_list = [str(env_idx + 1) for env_idx in incomplete_envs[i:i + 5]]
            summary_lines.append(" " + ", ".join(env_display_list))
    else:
        summary_lines.append("\n🎉 Congratulations! All environments are completed!")
    return "\n".join(summary_lines)
def create_trade_interface(current_dir: str, save_dir: str, user_id_input: gr.Textbox) -> Tuple:
    """Build the Trade task UI panel, initially hidden.

    The environment controls (env index input, init/reset buttons, env info)
    are created by the main interface and placed below the progress summary.
    ``None`` placeholders occupy their slots so the returned tuple keeps its
    shape, and the main interface ignores them. *current_dir*, *save_dir* and
    *user_id_input* are accepted for signature compatibility and are not used here.

    Returns:
        ``(trade_interface, None, None, None, None, trade_state_display,
        trade_steps_info_text, trade_stock_inputs, trade_step_btn,
        trade_feedback_display, trade_history_display)``. ``trade_stock_inputs``
        maps "S0".."S9" to hidden ``gr.Number`` boxes (positive = buy,
        negative = sell); the caller shows them per loaded environment.
    """
    MAX_STOCKS = 10  # at most 10 stock input boxes
    stocks_per_row = 4
    trade_stock_inputs = {}

    # Main panel row (environment controls are not part of it).
    with gr.Row(visible=False) as trade_interface:
        # Left column: step counter and action history.
        with gr.Column(scale=1):
            trade_steps_info_text = gr.Textbox(
                label="Steps Info",
                value="0 / 120",
                interactive=False,
                visible=True,
                lines=2,
            )
            gr.Markdown("### 📜 Action History")
            trade_history_display = gr.Textbox(
                label="Action History",
                interactive=False,
                lines=10,
            )
        # Right column: market state, per-stock order inputs and the execute button.
        with gr.Column(scale=1):
            gr.Markdown("### 💹 Current Task State")
            trade_state_display = gr.Textbox(
                label="Market State",
                interactive=False,
                lines=10,
                value="Please load environment first",
            )
            gr.Markdown("### 🎯 Trading Operations (Positive = Buy, Negative = Sell)")
            # Rows of up to 4 boxes: S0-S3, S4-S7, S8-S9. All start hidden and
            # are revealed once an environment reports its stock count.
            for first_idx in range(0, MAX_STOCKS, stocks_per_row):
                with gr.Row():
                    for idx in range(first_idx, min(first_idx + stocks_per_row, MAX_STOCKS)):
                        box_name = f"S{idx}"
                        trade_stock_inputs[box_name] = gr.Number(
                            label=f"{box_name}",
                            value=0,
                            precision=0,
                            step=1,
                            visible=False,
                        )
            trade_step_btn = gr.Button("Execute Trade", variant="primary")
            # Kept (hidden) only so callers that expect a feedback component still work.
            trade_feedback_display = gr.Textbox(
                label="Feedback Info",
                interactive=False,
                lines=5,
                visible=False,
            )

    return (trade_interface, None, None, None,
            None, trade_state_display, trade_steps_info_text,
            trade_stock_inputs, trade_step_btn, trade_feedback_display, trade_history_display)