Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import os | |
| import time | |
| import logging | |
| from tradingagents.graph.trading_graph import TradingAgentsGraph | |
| from tradingagents.default_config import DEFAULT_CONFIG | |
| # 配置日志输出格式 | |
| logging.basicConfig( | |
| level=logging.INFO, | |
| format="%(asctime)s - %(levelname)s - %(message)s", | |
| datefmt="%Y-%m-%d %H:%M:%S" | |
| ) | |
| logger = logging.getLogger("TradingAgents") | |
| # ---------------------- 1. 初始化 TradingAgents 框架 ---------------------- | |
| def init_trading_agents(): | |
| """初始化框架,读取环境变量(API Key),配置LLM""" | |
| logger.info("开始初始化 TradingAgents 框架...") | |
| # 复制默认配置并修改 | |
| custom_config = DEFAULT_CONFIG.copy() | |
| logger.info("加载默认配置并开始自定义配置...") | |
| # 1. 显式配置LLM提供商和模型(与项目默认规范匹配) | |
| custom_config["llm_provider"] = "openai" # 显式指定提供商 | |
| custom_config["backend_url"] = "https://api.vveai.com/v1" # 对应API地址 | |
| custom_config["deep_think_llm"] = "gpt-4o-mini" # 深度思考模型 | |
| custom_config["quick_think_llm"] = "gpt-4o-mini" # 快速思考模型 | |
| logger.info(f"配置LLM模型:提供商={custom_config['llm_provider']}, 深度模型={custom_config['deep_think_llm']}") | |
| # 2. 启用在线数据(依赖FinnHub) | |
| custom_config["online_tools"] = True | |
| logger.info(f"启用在线工具: {custom_config['online_tools']}") | |
| # 3. 配置辩论轮次(缩短响应时间) | |
| custom_config["max_debate_rounds"] = 1 | |
| custom_config["max_risk_discuss_rounds"] = 1 # 补充风险辩论轮次配置 | |
| logger.info(f"配置辩论参数: 最大辩论轮次={custom_config['max_debate_rounds']}, 风险讨论轮次={custom_config['max_risk_discuss_rounds']}") | |
| # 4. 检查必要环境变量(避免API调用失败) | |
| required_env_vars = ["OPENAI_API_KEY"] | |
| if custom_config["online_tools"]: | |
| required_env_vars.append("FINNHUB_API_KEY") | |
| missing_vars = [var for var in required_env_vars if not os.getenv(var)] | |
| if missing_vars: | |
| error_msg = f"缺少必要环境变量:{', '.join(missing_vars)},请配置后重试" | |
| logger.error(error_msg) | |
| raise ValueError(error_msg) | |
| logger.info("所有必要环境变量均已配置") | |
| # 5. 初始化框架(debug模式便于调试) | |
| logger.info("开始初始化 TradingAgentsGraph 实例...") | |
| ta = TradingAgentsGraph(debug=True, config=custom_config) | |
| logger.info("TradingAgents 框架初始化完成") | |
| return ta | |
| # 全局初始化(只执行一次) | |
| try: | |
| logger.info("===== 启动 TradingAgents 应用 =====") | |
| ta = init_trading_agents() | |
| except Exception as e: | |
| logger.critical(f"框架初始化失败: {str(e)}", exc_info=True) | |
| raise # 终止程序,初始化失败无法继续运行 | |
| # ---------------------- 2. 定义交易决策函数 ---------------------- | |
| def generate_trading_decision(ticker, analysis_date, progress=gr.Progress()): | |
| """核心函数:生成交易决策,显示中间分析过程""" | |
| logger.info(f"收到分析请求 - 股票代码: {ticker}, 分析日期: {analysis_date}") | |
| try: | |
| # 1. 验证输入(与之前逻辑一致) | |
| if not ticker.strip(): | |
| error_msg = "错误:股票代码不能为空!请输入如 AAPL、SPY 的代码。" | |
| logger.warning(error_msg) | |
| return error_msg | |
| if not analysis_date.strip(): | |
| error_msg = "错误:分析日期不能为空!请输入格式如 2024-05-10 的日期。" | |
| logger.warning(error_msg) | |
| return error_msg | |
| try: | |
| time.strptime(analysis_date.strip(), "%Y-%m-%d") | |
| except ValueError: | |
| error_msg = f"错误:日期格式无效!请使用 YYYY-MM-DD 格式,输入的日期为: {analysis_date}" | |
| logger.warning(error_msg) | |
| return error_msg | |
| logger.info("输入验证通过,开始分析流程") | |
| # 2. 显示进度 | |
| progress(0, desc="开始初始化分析...") | |
| time.sleep(1) | |
| progress(0.2, desc=f"获取 {ticker} 的金融数据(FinnHub)...") | |
| time.sleep(2) | |
| progress(0.5, desc="分析师团队分析(财务+情绪+技术指标)...") | |
| time.sleep(3) | |
| progress(0.8, desc="研究团队辩论+交易员决策+风险管理评估...") | |
| # 3. 调用框架核心方法,捕获中间结果 | |
| logger.info(f"调用 propagate 方法 - company_name={ticker.strip()}, trade_date={analysis_date.strip()}") | |
| start_time = time.time() | |
| # ---- 关键修改:调试模式下捕获流式中间结果 ---- | |
| all_outputs = [] # 存储所有中间输出 | |
| final_state = None | |
| processed_signal = None | |
| # 若处于调试模式(debug=True),使用 stream 获取中间结果 | |
| if ta.debug: | |
| # 获取 graph.invoke 所需的参数 | |
| args = ta.propagator.get_graph_args() | |
| init_state = ta.propagator.create_initial_state(ticker.strip(), analysis_date.strip()) | |
| # 遍历流式结果 | |
| for i, chunk in enumerate(ta.graph.stream(init_state, **args)): | |
| # 提取并格式化中间消息 | |
| if chunk.get("messages") and len(chunk["messages"]) > 0: | |
| msg = chunk["messages"][-1].content # 假设消息内容在 .content 中 | |
| all_outputs.append(f"## 中间步骤 {i+1}\n{msg}\n\n") | |
| logger.info(f"中间结果 {i+1}: {msg[:50]}...") # 日志简略显示 | |
| # 模拟进度(根据实际步骤数调整,这里简化) | |
| progress(0.8 + 0.2 * i / 10, desc=f"处理中间步骤 {i+1}/10") | |
| # 流式结束后,最终状态为最后一个 chunk | |
| final_state = chunk | |
| else: | |
| # 非调试模式,直接调用(无中间结果) | |
| final_state, processed_signal = ta.propagate( | |
| company_name=ticker.strip(), | |
| trade_date=analysis_date.strip() | |
| ) | |
| all_outputs.append("(非调试模式,无中间步骤显示)\n\n") | |
| # 处理最终结果 | |
| if final_state: | |
| decision = final_state.get("final_trade_decision", "未获取到最终交易决策") | |
| all_outputs.append(f"# 最终交易决策\n{decision}\n\n") | |
| # 添加免责声明 | |
| disclaimer = "\n\n【免责声明】本结果仅用于研究目的,不构成任何财务、投资或交易建议。交易风险自负。" | |
| final_result = f"# {ticker} 交易决策报告(分析日期:{analysis_date})\n\n" + "".join(all_outputs) + disclaimer | |
| end_time = time.time() | |
| logger.info(f"分析完成,耗时 {end_time - start_time:.2f} 秒") | |
| progress(1.0, desc="分析完成!") | |
| return final_result | |
| except Exception as e: | |
| error_msg = ( | |
| f"分析失败:{str(e)}\n\n可能原因:\n" | |
| "1. API Key 无效或已过期\n" | |
| "2. 股票代码不存在\n" | |
| "3. 分析日期格式错误(需 YYYY-MM-DD)\n" | |
| "4. FinnHub/OpenAI API 调用超限" | |
| ) | |
| logger.error(f"分析过程出错: {str(e)}", exc_info=True) | |
| return error_msg | |
| # ---------------------- 3. 构建 Gradio 界面 ---------------------- | |
| logger.info("开始构建 Gradio 界面...") | |
| with gr.Blocks(title="TradingAgents 金融交易决策工具") as demo: | |
| gr.Markdown("# TradingAgents 多智能体 LLM 金融交易工具") | |
| gr.Markdown("基于大语言模型的多智能体协作分析,支持股票财务、情绪、技术指标综合评估") | |
| gr.Markdown("⚠️ 免费版 API 有调用限制,单次分析约 1-3 分钟,请耐心等待") | |
| with gr.Row(): | |
| ticker_input = gr.Textbox( | |
| label="股票代码", | |
| value="AAPL", | |
| placeholder="输入股票代码,如 AAPL(苹果)、SPY(标普 500 ETF)", | |
| interactive=True | |
| ) | |
| date_input = gr.Textbox( | |
| label="分析日期", | |
| value="2024-05-10", | |
| placeholder="输入格式:YYYY-MM-DD(如 2024-05-10)", | |
| interactive=True | |
| ) | |
| submit_btn = gr.Button("生成交易决策", variant="primary") | |
| result_output = gr.Markdown( | |
| label="分析结果", | |
| value="点击上方按钮开始分析,结果将显示在这里..." | |
| ) | |
| submit_btn.click( | |
| fn=generate_trading_decision, | |
| inputs=[ticker_input, date_input], | |
| outputs=result_output | |
| ) | |
| logger.info("Gradio 界面构建完成") | |
| # ---------------------- 4. 启动界面 ---------------------- | |
| if __name__ == "__main__": | |
| logger.info("启动 Gradio 服务...") | |
| try: | |
| demo.launch(server_name="0.0.0.0", server_port=7860) | |
| logger.info("Gradio 服务已启动") | |
| except Exception as e: | |
| logger.critical(f"Gradio 服务启动失败: {str(e)}", exc_info=True) |