JC321 commited on
Commit
2b5eb7d
·
verified ·
1 Parent(s): c824ed2

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +394 -251
app.py CHANGED
@@ -1,251 +1,394 @@
1
- """
2
- SEC Financial Data MCP Server with Gradio UI
3
- Provides standard MCP service + Web interface for testing
4
- """
5
- import os
6
- import gradio as gr
7
- import json
8
- from mcp.server.fastmcp import FastMCP
9
- from edgar_client import EdgarDataClient
10
- from financial_analyzer import FinancialAnalyzer
11
-
12
- # Initialize FastMCP server
13
- mcp = FastMCP("sec-financial-data")
14
-
15
- # Initialize EDGAR clients
16
- edgar_client = EdgarDataClient(
17
- user_agent="Juntao Peng Financial Report Metrics App (jtyxabc@gmail.com)"
18
- )
19
-
20
- financial_analyzer = FinancialAnalyzer(
21
- user_agent="Juntao Peng Financial Report Metrics App (jtyxabc@gmail.com)"
22
- )
23
-
24
- # Define MCP tools
25
- @mcp.tool()
26
- def search_company(company_name: str) -> dict:
27
- """Search for a company by name in SEC EDGAR database."""
28
- result = edgar_client.search_company_by_name(company_name)
29
- return result if result else {"error": f"No company found with name: {company_name}"}
30
-
31
- @mcp.tool()
32
- def get_company_info(cik: str) -> dict:
33
- """Get detailed company information."""
34
- result = edgar_client.get_company_info(cik)
35
- return result if result else {"error": f"No company found with CIK: {cik}"}
36
-
37
- @mcp.tool()
38
- def get_company_filings(cik: str, form_types: list[str] | None = None) -> dict:
39
- """Get list of company SEC filings."""
40
- form_types_tuple = tuple(form_types) if form_types else None
41
- result = edgar_client.get_company_filings(cik, form_types_tuple)
42
- if result:
43
- limited_result = result[:20]
44
- return {
45
- "total": len(result),
46
- "returned": len(limited_result),
47
- "filings": limited_result
48
- }
49
- return {"error": f"No filings found for CIK: {cik}"}
50
-
51
- @mcp.tool()
52
- def get_financial_data(cik: str, period: str) -> dict:
53
- """Get financial data for a specific period."""
54
- result = edgar_client.get_financial_data_for_period(cik, period)
55
- return result if result and "period" in result else {"error": f"No financial data found for CIK: {cik}, Period: {period}"}
56
-
57
- @mcp.tool()
58
- def extract_financial_metrics(cik: str, years: int = 3) -> dict:
59
- """Extract comprehensive financial metrics for multiple years."""
60
- if years < 1 or years > 10:
61
- return {"error": "Years parameter must be between 1 and 10"}
62
-
63
- metrics = financial_analyzer.extract_financial_metrics(cik, years)
64
- if metrics:
65
- formatted = financial_analyzer.format_financial_data(metrics)
66
- return {"periods": len(formatted), "data": formatted}
67
- return {"error": f"No financial metrics extracted for CIK: {cik}"}
68
-
69
- @mcp.tool()
70
- def get_latest_financial_data(cik: str) -> dict:
71
- """Get the most recent financial data available."""
72
- result = financial_analyzer.get_latest_financial_data(cik)
73
- return result if result and "period" in result else {"error": f"No latest financial data found for CIK: {cik}"}
74
-
75
- @mcp.tool()
76
- def advanced_search_company(company_input: str) -> dict:
77
- """Advanced search supporting both company name and CIK code."""
78
- result = financial_analyzer.search_company(company_input)
79
- return result if not result.get("error") else {"error": result["error"]}
80
-
81
- # Gradio wrapper functions (添加调试和超时处理)
82
- def gradio_search_company(company_name: str):
83
- """Gradio wrapper for search_company"""
84
- if not company_name or not company_name.strip():
85
- return json.dumps({"error": "Company name cannot be empty"}, indent=2)
86
- try:
87
- import sys
88
- print(f"[DEBUG] Searching company: {company_name.strip()}", file=sys.stderr)
89
- result = search_company(company_name.strip())
90
- print(f"[DEBUG] Search result type: {type(result)}", file=sys.stderr)
91
- print(f"[DEBUG] Search result: {result}", file=sys.stderr)
92
- # Ensure result is a dict
93
- if not isinstance(result, dict):
94
- result = {"error": f"Unexpected result type: {type(result)}"}
95
- return json.dumps(result, indent=2)
96
- except TimeoutError as e:
97
- return json.dumps({"error": f"Request timeout: {str(e)}"}, indent=2)
98
- except Exception as e:
99
- import traceback
100
- traceback.print_exc()
101
- return json.dumps({"error": f"Exception: {str(e)}", "type": str(type(e))}, indent=2)
102
-
103
- def gradio_get_company_info(cik: str):
104
- """Gradio wrapper for get_company_info"""
105
- if not cik or not cik.strip():
106
- return json.dumps({"error": "CIK cannot be empty"}, indent=2)
107
- try:
108
- import sys
109
- print(f"[DEBUG] Getting company info for CIK: {cik.strip()}", file=sys.stderr)
110
- result = get_company_info(cik.strip())
111
- print(f"[DEBUG] Company info result: {result}", file=sys.stderr)
112
- if not isinstance(result, dict):
113
- result = {"error": f"Unexpected result type: {type(result)}"}
114
- return json.dumps(result, indent=2)
115
- except TimeoutError as e:
116
- return json.dumps({"error": f"Request timeout: {str(e)}"}, indent=2)
117
- except Exception as e:
118
- import traceback
119
- traceback.print_exc()
120
- return json.dumps({"error": f"Exception: {str(e)}", "type": str(type(e))}, indent=2)
121
-
122
- def gradio_extract_metrics(cik: str, years: float):
123
- """Gradio wrapper for extract_financial_metrics"""
124
- if not cik or not cik.strip():
125
- return json.dumps({"error": "CIK cannot be empty"}, indent=2)
126
- try:
127
- import sys
128
- years_int = int(years)
129
- print(f"[DEBUG] Extracting metrics for CIK: {cik.strip()}, Years: {years_int}", file=sys.stderr)
130
- result = extract_financial_metrics(cik.strip(), years_int)
131
- print(f"[DEBUG] Extract metrics result: {result}", file=sys.stderr)
132
- if not isinstance(result, dict):
133
- result = {"error": f"Unexpected result type: {type(result)}"}
134
- return json.dumps(result, indent=2)
135
- except TimeoutError as e:
136
- return json.dumps({"error": f"Request timeout: {str(e)}"}, indent=2)
137
- except Exception as e:
138
- import traceback
139
- traceback.print_exc()
140
- return json.dumps({"error": f"Exception: {str(e)}", "type": str(type(e))}, indent=2)
141
-
142
- def gradio_get_latest(cik: str):
143
- """Gradio wrapper for get_latest_financial_data"""
144
- if not cik or not cik.strip():
145
- return json.dumps({"error": "CIK cannot be empty"}, indent=2)
146
- try:
147
- import sys
148
- print(f"[DEBUG] Getting latest data for CIK: {cik.strip()}", file=sys.stderr)
149
- result = get_latest_financial_data(cik.strip())
150
- print(f"[DEBUG] Latest data result: {result}", file=sys.stderr)
151
- if not isinstance(result, dict):
152
- result = {"error": f"Unexpected result type: {type(result)}"}
153
- return json.dumps(result, indent=2)
154
- except TimeoutError as e:
155
- return json.dumps({"error": f"Request timeout: {str(e)}"}, indent=2)
156
- except Exception as e:
157
- import traceback
158
- traceback.print_exc()
159
- return json.dumps({"error": f"Exception: {str(e)}", "type": str(type(e))}, indent=2)
160
-
161
- # Create Gradio interface
162
- with gr.Blocks(title="SEC Financial Data MCP Server", theme=gr.themes.Soft()) as demo:
163
- gr.Markdown("""
164
- # 📊 SEC Financial Data MCP Server
165
-
166
- Access real-time SEC EDGAR financial data via Model Context Protocol
167
-
168
- **MCP Endpoint:** Use `mcp_server_fastmcp.py` for standard MCP client connections
169
- """)
170
-
171
- with gr.Tab("🔍 Search Company"):
172
- gr.Markdown("### Search for a company by name")
173
- with gr.Row():
174
- with gr.Column():
175
- company_input = gr.Textbox(label="Company Name", placeholder="Tesla", value="Tesla")
176
- search_btn = gr.Button("Search", variant="primary")
177
- with gr.Column():
178
- search_output = gr.Code(label="Result", language="json", lines=15)
179
- search_btn.click(gradio_search_company, inputs=company_input, outputs=search_output)
180
-
181
- with gr.Tab("ℹ️ Company Info"):
182
- gr.Markdown("### Get detailed company information")
183
- with gr.Row():
184
- with gr.Column():
185
- cik_input = gr.Textbox(label="Company CIK", placeholder="0001318605", value="0001318605")
186
- info_btn = gr.Button("Get Info", variant="primary")
187
- with gr.Column():
188
- info_output = gr.Code(label="Result", language="json", lines=15)
189
- info_btn.click(gradio_get_company_info, inputs=cik_input, outputs=info_output)
190
-
191
- with gr.Tab("📈 Financial Metrics"):
192
- gr.Markdown("### Extract multi-year financial metrics ⭐")
193
- with gr.Row():
194
- with gr.Column():
195
- metrics_cik = gr.Textbox(label="Company CIK", placeholder="0001318605", value="0001318605")
196
- metrics_years = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Years")
197
- metrics_btn = gr.Button("Extract Metrics", variant="primary")
198
- with gr.Column():
199
- metrics_output = gr.Code(label="Result", language="json", lines=20)
200
- metrics_btn.click(gradio_extract_metrics, inputs=[metrics_cik, metrics_years], outputs=metrics_output)
201
-
202
- with gr.Tab("🆕 Latest Data"):
203
- gr.Markdown("### Get latest financial data")
204
- with gr.Row():
205
- with gr.Column():
206
- latest_cik = gr.Textbox(label="Company CIK", placeholder="0001318605", value="0001318605")
207
- latest_btn = gr.Button("Get Latest", variant="primary")
208
- with gr.Column():
209
- latest_output = gr.Code(label="Result", language="json", lines=15)
210
- latest_btn.click(gradio_get_latest, inputs=latest_cik, outputs=latest_output)
211
-
212
- with gr.Tab("📖 Documentation"):
213
- gr.Markdown("""
214
- ## 🛠️ Available Tools (7)
215
-
216
- 1. **search_company** - Search by company name
217
- 2. **get_company_info** - Get company details by CIK
218
- 3. **get_company_filings** - List SEC filings
219
- 4. **get_financial_data** - Get specific period data
220
- 5. **extract_financial_metrics** ⭐ - Multi-year trends
221
- 6. **get_latest_financial_data** - Latest snapshot
222
- 7. **advanced_search_company** - Flexible search
223
-
224
- ## 🔗 MCP Integration
225
-
226
- For MCP client integration, run:
227
- ```bash
228
- python mcp_server_fastmcp.py
229
- ```
230
-
231
- Then configure your MCP client (e.g., Claude Desktop):
232
- ```json
233
- {
234
- "mcpServers": {
235
- "sec-financial-data": {
236
- "command": "python",
237
- "args": ["path/to/mcp_server_fastmcp.py"]
238
- }
239
- }
240
- }
241
- ```
242
-
243
- ## 📊 Data Source
244
-
245
- - **SEC EDGAR API** - Official SEC data
246
- - **Financial Statements** - 10-K, 10-Q, 20-F forms
247
- - **XBRL Data** - Structured metrics
248
- """)
249
-
250
- if __name__ == "__main__":
251
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import requests
3
+ import json
4
+ import os
5
+ import warnings
6
+ from huggingface_hub import InferenceClient
7
+
8
+ # 抑制 asyncio 警告
9
+ warnings.filterwarnings('ignore', category=DeprecationWarning)
10
+ os.environ['PYTHONWARNINGS'] = 'ignore'
11
+
12
+ # 如果在 GPU 环境但不需要 GPU,禁用 CUDA
13
+ if 'CUDA_VISIBLE_DEVICES' not in os.environ:
14
+ os.environ['CUDA_VISIBLE_DEVICES'] = ''
15
+
16
+ # ========== MCP 工具简化定义(符合MCP协议标准) ==========
17
+ MCP_TOOLS = [
18
+ {"type": "function", "function": {"name": "advanced_search_company", "description": "Search US companies", "parameters": {"type": "object", "properties": {"company_input": {"type": "string"}}, "required": ["company_input"]}}},
19
+ {"type": "function", "function": {"name": "get_latest_financial_data", "description": "Get latest financial data", "parameters": {"type": "object", "properties": {"cik": {"type": "string"}}, "required": ["cik"]}}},
20
+ {"type": "function", "function": {"name": "extract_financial_metrics", "description": "Get multi-year trends", "parameters": {"type": "object", "properties": {"cik": {"type": "string"}, "years": {"type": "integer"}}, "required": ["cik", "years"]}}},
21
+ {"type": "function", "function": {"name": "get_quote", "description": "Get stock quote", "parameters": {"type": "object", "properties": {"symbol": {"type": "string"}}, "required": ["symbol"]}}},
22
+ {"type": "function", "function": {"name": "get_market_news", "description": "Get market news", "parameters": {"type": "object", "properties": {"category": {"type": "string"}}, "required": ["category"]}}},
23
+ {"type": "function", "function": {"name": "get_company_news", "description": "Get company news", "parameters": {"type": "object", "properties": {"symbol": {"type": "string"}, "from_date": {"type": "string"}, "to_date": {"type": "string"}}, "required": ["symbol"]}}}
24
+ ]
25
+
26
+ # ========== MCP 服务配置 ==========
27
+ MCP_SERVICES = {
28
+ "financial": {"url": "https://huggingface.co/spaces/JC321/EasyReportDataMCP", "type": "fastmcp"},
29
+ "market": {"url": "https://jc321-marketandstockmcp.hf.space", "type": "gradio"}
30
+ }
31
+
32
+ TOOL_ROUTING = {
33
+ "advanced_search_company": MCP_SERVICES["financial"],
34
+ "get_latest_financial_data": MCP_SERVICES["financial"],
35
+ "extract_financial_metrics": MCP_SERVICES["financial"],
36
+ "get_quote": MCP_SERVICES["market"],
37
+ "get_market_news": MCP_SERVICES["market"],
38
+ "get_company_news": MCP_SERVICES["market"]
39
+ }
40
+
41
+ # ========== 初始化 LLM 客户端 ==========
42
+ hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
43
+ client = InferenceClient(api_key=hf_token) if hf_token else InferenceClient()
44
+ print(f"✅ LLM initialized: Qwen/Qwen3-32B:groq")
45
+ print(f"📊 MCP Services: {len(MCP_SERVICES)} services, {len(MCP_TOOLS)} tools")
46
+
47
+ # ========== Token 限制配置 ==========
48
+ # HuggingFace Inference API 实际限制约 8000-16000 tokens
49
+ # 为了安全,设置更低的限制
50
+ MAX_TOTAL_TOKENS = 6000 # 总上下文限制
51
+ MAX_TOOL_RESULT_CHARS = 1500 # 工具返回最大字符数 (增加到1500)
52
+ MAX_HISTORY_CHARS = 500 # 单条历史消息最大字符数
53
+ MAX_HISTORY_TURNS = 2 # 最大历史轮数
54
+ MAX_TOOL_ITERATIONS = 6 # 最大工具调用轮数 (增加到6,支持多工具调用)
55
+ MAX_OUTPUT_TOKENS = 2000 # 最大输出 tokens (增加到2000)
56
+
57
+ def estimate_tokens(text):
58
+ """估算文本 token 数量(粗略:1 token 2 字符)"""
59
+ return len(str(text)) // 2
60
+
61
+ def truncate_text(text, max_chars, suffix="...[truncated]"):
62
+ """截断文本到指定长度"""
63
+ text = str(text)
64
+ if len(text) <= max_chars:
65
+ return text
66
+ return text[:max_chars] + suffix
67
+
68
+ def get_system_prompt():
69
+ """生成包含当前日期的系统提示词(精简版)"""
70
+ from datetime import datetime
71
+ current_date = datetime.now().strftime("%Y-%m-%d")
72
+ return f"""Financial analyst. Today: {current_date}. Use tools for company data, stock prices, news. Be concise."""
73
+
74
+ # ============================================================
75
+ # MCP 服务调用核心代码区
76
+ # 支持 FastMCP (JSON-RPC) 和 Gradio (SSE) 两种协议
77
+ # ============================================================
78
+
79
+ def call_mcp_tool(tool_name, arguments):
80
+ """调用 MCP 工具"""
81
+ service_config = TOOL_ROUTING.get(tool_name)
82
+ if not service_config:
83
+ return {"error": f"Unknown tool: {tool_name}"}
84
+
85
+ try:
86
+ if service_config["type"] == "fastmcp":
87
+ return _call_fastmcp(service_config["url"], tool_name, arguments)
88
+ elif service_config["type"] == "gradio":
89
+ return _call_gradio_api(service_config["url"], tool_name, arguments)
90
+ else:
91
+ return {"error": "Unknown service type"}
92
+ except Exception as e:
93
+ return {"error": str(e)}
94
+
95
+
96
+ def _call_fastmcp(service_url, tool_name, arguments):
97
+ """FastMCP: 标准 MCP JSON-RPC"""
98
+ response = requests.post(
99
+ service_url,
100
+ json={"jsonrpc": "2.0", "method": "tools/call", "params": {"name": tool_name, "arguments": arguments}, "id": 1},
101
+ headers={"Content-Type": "application/json"},
102
+ timeout=30
103
+ )
104
+
105
+ if response.status_code != 200:
106
+ return {"error": f"HTTP {response.status_code}"}
107
+
108
+ data = response.json()
109
+
110
+ # 解包 MCP 协议: jsonrpc -> result -> content[0].text -> JSON
111
+ if isinstance(data, dict) and "result" in data:
112
+ result = data["result"]
113
+ if isinstance(result, dict) and "content" in result:
114
+ content = result["content"]
115
+ if isinstance(content, list) and len(content) > 0:
116
+ first_item = content[0]
117
+ if isinstance(first_item, dict) and "text" in first_item:
118
+ try:
119
+ return json.loads(first_item["text"])
120
+ except (json.JSONDecodeError, TypeError):
121
+ return {"text": first_item["text"]}
122
+ return result
123
+ return data
124
+
125
+
126
+ def _call_gradio_api(service_url, tool_name, arguments):
127
+ """Gradio: SSE 流式协议"""
128
+ tool_map = {"get_quote": "test_quote_tool", "get_market_news": "test_market_news_tool", "get_company_news": "test_company_news_tool"}
129
+ gradio_fn = tool_map.get(tool_name)
130
+ if not gradio_fn:
131
+ return {"error": "No mapping"}
132
+
133
+ # 构造参数
134
+ if tool_name == "get_quote":
135
+ params = [arguments.get("symbol", "")]
136
+ elif tool_name == "get_market_news":
137
+ params = [arguments.get("category", "general")]
138
+ elif tool_name == "get_company_news":
139
+ params = [arguments.get("symbol", ""), arguments.get("from_date", ""), arguments.get("to_date", "")]
140
+ else:
141
+ params = []
142
+
143
+ # 提交请求
144
+ call_url = f"{service_url}/call/{gradio_fn}"
145
+ resp = requests.post(call_url, json={"data": params}, timeout=10)
146
+ if resp.status_code != 200:
147
+ return {"error": f"HTTP {resp.status_code}"}
148
+
149
+ event_id = resp.json().get("event_id")
150
+ if not event_id:
151
+ return {"error": "No event_id"}
152
+
153
+ # 获取结果 (SSE)
154
+ result_resp = requests.get(f"{call_url}/{event_id}", stream=True, timeout=20)
155
+ if result_resp.status_code != 200:
156
+ return {"error": f"HTTP {result_resp.status_code}"}
157
+
158
+ # 解析 SSE
159
+ for line in result_resp.iter_lines():
160
+ if line and line.decode('utf-8').startswith('data: '):
161
+ try:
162
+ result_data = json.loads(line.decode('utf-8')[6:])
163
+ if isinstance(result_data, list) and len(result_data) > 0:
164
+ return {"text": result_data[0]}
165
+ except json.JSONDecodeError:
166
+ continue
167
+
168
+ return {"error": "No result"}
169
+
170
+ # ============================================================
171
+ # End of MCP 服务调用代码区
172
+ # ============================================================
173
+
174
+ def chatbot_response(message, history):
175
+ """AI 助手主函数(流式输出,性能优化)"""
176
+ try:
177
+ messages = [{"role": "system", "content": get_system_prompt()}]
178
+
179
+ # 添加历史(最近2轮) - 严格限制上下文长度
180
+ if history:
181
+ for item in history[-MAX_HISTORY_TURNS:]:
182
+ if isinstance(item, (list, tuple)) and len(item) == 2:
183
+ # 用户消息(不截断)
184
+ messages.append({"role": "user", "content": item[0]})
185
+
186
+ # 助手回复(严格截断)
187
+ assistant_msg = str(item[1])
188
+ if len(assistant_msg) > MAX_HISTORY_CHARS:
189
+ assistant_msg = truncate_text(assistant_msg, MAX_HISTORY_CHARS)
190
+ messages.append({"role": "assistant", "content": assistant_msg})
191
+
192
+ messages.append({"role": "user", "content": message})
193
+
194
+ tool_calls_log = []
195
+
196
+ # LLM 调用循环(支持多轮工具调用)
197
+ final_response_content = None
198
+ for iteration in range(MAX_TOOL_ITERATIONS):
199
+ response = client.chat.completions.create(
200
+ model="Qwen/Qwen3-32B:groq",
201
+ messages=messages,
202
+ tools=MCP_TOOLS,
203
+ max_tokens=MAX_OUTPUT_TOKENS,
204
+ temperature=0.5,
205
+ tool_choice="auto",
206
+ stream=False
207
+ )
208
+
209
+ choice = response.choices[0]
210
+
211
+ if choice.message.tool_calls:
212
+ messages.append(choice.message)
213
+
214
+ for tool_call in choice.message.tool_calls:
215
+ tool_name = tool_call.function.name
216
+ try:
217
+ tool_args = json.loads(tool_call.function.arguments)
218
+ except json.JSONDecodeError:
219
+ tool_args = {}
220
+
221
+ # 调用 MCP 工具
222
+ tool_result = call_mcp_tool(tool_name, tool_args)
223
+
224
+ # 检查错误
225
+ if isinstance(tool_result, dict) and "error" in tool_result:
226
+ # 工具调用失败,记录错误
227
+ tool_calls_log.append({"name": tool_name, "arguments": tool_args, "result": tool_result, "error": True})
228
+ result_for_llm = json.dumps({"error": tool_result.get("error", "Unknown error")}, ensure_ascii=False)
229
+ else:
230
+ # 限制返回结果大小
231
+ result_str = json.dumps(tool_result, ensure_ascii=False)
232
+
233
+ if len(result_str) > MAX_TOOL_RESULT_CHARS:
234
+ if isinstance(tool_result, dict) and "text" in tool_result:
235
+ truncated_text = truncate_text(tool_result["text"], MAX_TOOL_RESULT_CHARS - 50)
236
+ tool_result_truncated = {"text": truncated_text, "_truncated": True}
237
+ elif isinstance(tool_result, dict):
238
+ truncated = {}
239
+ char_count = 0
240
+ for k, v in list(tool_result.items())[:8]: # 保留前8个字段
241
+ v_str = str(v)[:300] # 每个值最多300字符
242
+ truncated[k] = v_str
243
+ char_count += len(k) + len(v_str)
244
+ if char_count > MAX_TOOL_RESULT_CHARS:
245
+ break
246
+ tool_result_truncated = {**truncated, "_truncated": True}
247
+ else:
248
+ tool_result_truncated = {"preview": truncate_text(result_str, MAX_TOOL_RESULT_CHARS), "_truncated": True}
249
+ result_for_llm = json.dumps(tool_result_truncated, ensure_ascii=False)
250
+ else:
251
+ result_for_llm = result_str
252
+
253
+ # 记录成功的工具调用
254
+ tool_calls_log.append({"name": tool_name, "arguments": tool_args, "result": tool_result})
255
+
256
+ messages.append({
257
+ "role": "tool",
258
+ "name": tool_name,
259
+ "content": result_for_llm,
260
+ "tool_call_id": tool_call.id
261
+ })
262
+
263
+ continue
264
+ else:
265
+ # 没有更多工具调用,保存最终答案
266
+ final_response_content = choice.message.content
267
+ break
268
+
269
+ # 构建响应前缀(简化版)
270
+ response_prefix = ""
271
+
272
+ # 显示工具调用(使用原生HTML details标签)
273
+ if tool_calls_log:
274
+ response_prefix += """<div style='margin-bottom: 15px;'>
275
+ <div style='background: #f0f0f0; padding: 8px 12px; border-radius: 6px; font-weight: 600; color: #333;'>
276
+ 🛠️ Tools Used ({} calls)
277
+ </div>
278
+ """.format(len(tool_calls_log))
279
+
280
+ for idx, tool_call in enumerate(tool_calls_log):
281
+ # 预先计算 JSON 字符串,避免重复调用
282
+ args_json = json.dumps(tool_call['arguments'], ensure_ascii=False)
283
+ result_json = json.dumps(tool_call.get('result', {}), ensure_ascii=False, indent=2)
284
+ result_preview = result_json[:1500] + ('...' if len(result_json) > 1500 else '')
285
+
286
+ # 显示错误状态
287
+ error_indicator = " ❌ Error" if tool_call.get('error') else ""
288
+
289
+ # 使用原生 HTML5 details/summary 标签(不需要 JavaScript)
290
+ response_prefix += f"""<details style='margin: 8px 0; border: 1px solid #ddd; border-radius: 6px; overflow: hidden;'>
291
+ <summary style='background: #fff; padding: 10px; cursor: pointer; user-select: none; list-style: none;'>
292
+ <div style='display: flex; justify-content: space-between; align-items: center;'>
293
+ <div style='flex: 1;'>
294
+ <strong style='color: #2c5aa0;'>📌 {idx+1}. {tool_call['name']}{error_indicator}</strong>
295
+ <div style='font-size: 0.85em; color: #666; margin-top: 4px;'>📥 Input: <code style='background: #f5f5f5; padding: 2px 6px; border-radius: 3px;'>{args_json}</code></div>
296
+ </div>
297
+ <span style='font-size: 1.2em; color: #999; margin-left: 10px;'>▶</span>
298
+ </div>
299
+ </summary>
300
+ <div style='background: #f9f9f9; padding: 12px; border-top: 1px solid #eee;'>
301
+ <div style='font-size: 0.9em; color: #333;'>
302
+ <strong>📤 Output:</strong>
303
+ <pre style='background: #fff; padding: 10px; border-radius: 4px; overflow-x: auto; margin-top: 6px; font-size: 0.85em; border: 1px solid #e0e0e0; max-height: 400px; white-space: pre-wrap;'>{result_preview}</pre>
304
+ </div>
305
+ </div>
306
+ </details>
307
+ """
308
+
309
+ response_prefix += """</div>
310
+
311
+ ---
312
+
313
+ """
314
+ response_prefix += "\n"
315
+
316
+ # 流式输出最终答案
317
+ yield response_prefix
318
+
319
+ # 如果已经有最终答案,直接输出
320
+ if final_response_content:
321
+ # 已经从循环中获得了最终答案,直接输出
322
+ yield response_prefix + final_response_content
323
+ else:
324
+ # 如果循环结束但没有最终答案(达到最大迭代次数),需要再调用一次让模型总结
325
+ try:
326
+ stream = client.chat.completions.create(
327
+ model="Qwen/Qwen3-32B:groq",
328
+ messages=messages,
329
+ tools=None, # 不再允许调用工具
330
+ max_tokens=MAX_OUTPUT_TOKENS,
331
+ temperature=0.5,
332
+ stream=True
333
+ )
334
+
335
+ accumulated_text = ""
336
+ for chunk in stream:
337
+ if chunk.choices and len(chunk.choices) > 0 and chunk.choices[0].delta.content:
338
+ accumulated_text += chunk.choices[0].delta.content
339
+ yield response_prefix + accumulated_text
340
+ except Exception as stream_error:
341
+ # 流式输出失败,尝试非流式
342
+ final_resp = client.chat.completions.create(
343
+ model="Qwen/Qwen3-32B:groq",
344
+ messages=messages,
345
+ tools=None,
346
+ max_tokens=MAX_OUTPUT_TOKENS,
347
+ temperature=0.5,
348
+ stream=False
349
+ )
350
+ yield response_prefix + final_resp.choices[0].message.content
351
+
352
+ except Exception as e:
353
+ import traceback
354
+ error_detail = str(e)
355
+ if "500" in error_detail:
356
+ yield f"❌ Error: 模型服务器错误。可能是数据太大或请求超时。\n\n详细信息: {error_detail[:200]}"
357
+ else:
358
+ yield f"❌ Error: {error_detail}\n\n{traceback.format_exc()[:500]}"
359
+
360
+ # ========== Gradio 界面(极简版)==========
361
+ with gr.Blocks(title="Financial AI Assistant") as demo:
362
+ gr.Markdown("# 💬 Financial AI Assistant")
363
+
364
+ chat = gr.ChatInterface(
365
+ fn=chatbot_response,
366
+ examples=[
367
+ "What's Apple's latest revenue and profit?",
368
+ "Show me NVIDIA's 3-year financial trends",
369
+ "How is Tesla's stock performing today?",
370
+ "Get the latest market news about crypto",
371
+ "Compare Microsoft's latest earnings with its current stock price",
372
+ ],
373
+ chatbot=gr.Chatbot(height=600),
374
+ )
375
+
376
+ # 启动应用
377
+ if __name__ == "__main__":
378
+ import sys
379
+
380
+ # 修复 asyncio 事件循环问题
381
+ if sys.platform == 'linux':
382
+ try:
383
+ import asyncio
384
+ asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy())
385
+ except:
386
+ pass
387
+
388
+ demo.launch(
389
+ server_name="0.0.0.0",
390
+ server_port=7860,
391
+ show_error=True,
392
+ ssr_mode=False,
393
+ quiet=False
394
+ )