JC321 commited on
Commit
ae04bbf
·
verified ·
1 Parent(s): 7c66f9b

Upload chat_direct.py

Browse files
Files changed (1) hide show
  1. chat_direct.py +1085 -0
chat_direct.py ADDED
@@ -0,0 +1,1085 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
Financial AI Assistant - Direct Method Library (no HTTP dependency).

Imports and calls the functions defined in easy_financial_mcp.py directly.
Supports both local and HF Space deployment.
5
+ """
6
+
7
+ import sys
8
+ from pathlib import Path
9
+ import os
10
+ import json
11
+ from dotenv import load_dotenv
12
+ from huggingface_hub import InferenceClient
13
+ import requests
14
+ import warnings
15
+
16
# Silence DeprecationWarning noise (the original note refers to asyncio warnings).
warnings.filterwarnings('ignore', category=DeprecationWarning)
# Also exported through the environment; presumably so child interpreters
# inherit the same setting -- confirm if subprocesses are actually spawned.
os.environ['PYTHONWARNINGS'] = 'ignore'

# Load the .env file first so later os.environ lookups can see its values.
load_dotenv()

# Put the directory two levels above this file on sys.path so the service
# packages imported below can be found.
# The path is made absolute BEFORE walking up: with a relative __file__ such as
# "chat_direct.py", ".parent.parent" collapses to "." and would otherwise
# resolve to the current working directory instead of the grandparent.
PROJECT_ROOT = Path(__file__).absolute().parent.parent
sys.path.insert(0, str(PROJECT_ROOT))
26
+
27
+ # 直接导入 MCP 中定义的函数
28
# Import the MCP tool functions directly (no HTTP round trip). When the package
# cannot be imported, MCP_DIRECT_AVAILABLE is set to False and every alias is
# bound to a placeholder that returns an error dict, so the module still loads.
try:
    from EasyFinancialAgent.easy_financial_mcp import (
        search_company as _search_company,
        get_company_info as _get_company_info,
        get_company_filings as _get_company_filings,
        get_financial_data as _get_financial_data,
        extract_financial_metrics as _extract_financial_metrics,
        get_latest_financial_data as _get_latest_financial_data,
        advanced_search_company as _advanced_search_company
    )
    MCP_DIRECT_AVAILABLE = True
    print("[FinancialAI] ✓ Direct MCP functions imported successfully")
except ImportError as e:
    MCP_DIRECT_AVAILABLE = False
    print(f"[FinancialAI] ✗ Failed to import MCP functions: {e}")

    # Placeholder functions: one per alias imported above, so that every name
    # the success branch binds is also bound on the failure branch.
    def _search_company(x):
        return {"error": "MCP not available"}

    def _advanced_search_company(x):
        return {"error": "MCP not available"}

    def _get_company_info(x):
        return {"error": "MCP not available"}

    def _get_company_filings(x, y=None):
        return {"error": "MCP not available"}

    def _get_financial_data(x, y):
        return {"error": "MCP not available"}

    def _get_latest_financial_data(x):
        return {"error": "MCP not available"}

    def _extract_financial_metrics(x, y=3):
        return {"error": "MCP not available"}
56
+
57
+
58
+ # ============================================================
59
+ # 便捷方法 - 公司搜索相关
60
+ # ============================================================
61
+
62
def search_company_direct(company_input):
    """
    Search for a company by calling the advanced_search_company tool directly.

    Args:
        company_input: Company name, ticker symbol, or CIK code.

    Returns:
        On success, a one-element list wrapping the tool's result.
        On failure, a dict of the form {"error": ...}. This covers three
        cases: the MCP functions are unavailable, the tool raised, or the
        tool itself returned an error dict.

    Example:
        result = search_company_direct("Apple")
        result = search_company_direct("AAPL")
        result = search_company_direct("0000320193")
    """
    if not MCP_DIRECT_AVAILABLE:
        return {"error": "MCP functions not available"}

    try:
        result = _advanced_search_company(company_input)
    except Exception as e:
        return {"error": str(e)}

    # Hand an error reply back unwrapped: callers test `"error" in result`,
    # which would never match once the dict is hidden inside a list.
    if isinstance(result, dict) and "error" in result:
        return result
    return [result]
87
+
88
+
89
def get_company_info_direct(cik):
    """
    Fetch company details by calling the MCP function directly.

    Args:
        cik: Company CIK code.

    Returns:
        Whatever _get_company_info returns, or {"error": ...} when the MCP
        functions are unavailable or the call raises.

    Example:
        result = get_company_info_direct("0000320193")
    """
    if MCP_DIRECT_AVAILABLE:
        try:
            info = _get_company_info(cik)
        except Exception as exc:
            info = {"error": str(exc)}
        return info

    return {"error": "MCP functions not available"}
109
+
110
+
111
def get_company_filings_direct(cik):
    """
    Fetch a company's SEC filing list by calling the MCP function directly.

    Args:
        cik: Company CIK code.

    Returns:
        Whatever _get_company_filings returns, or {"error": ...} when the MCP
        functions are unavailable or the call raises.

    Example:
        result = get_company_filings_direct("0000320193")
    """
    if MCP_DIRECT_AVAILABLE:
        try:
            filings = _get_company_filings(cik)
        except Exception as exc:
            filings = {"error": str(exc)}
        return filings

    return {"error": "MCP functions not available"}
131
+
132
+
133
def advanced_search_company_detailed(company_input):
    """
    Advanced company search accepting a company name, ticker, or CIK.

    Unlike search_company_direct, the tool's result is returned as-is (not
    wrapped in a list), and failures include a formatted traceback.

    Args:
        company_input: Company name ("Tesla", "Apple Inc"),
            ticker symbol ("TSLA", "AAPL", "MSFT"),
            or CIK code ("0001318605", "0000320193").

    Returns:
        The value returned by the advanced_search_company tool. Per the
        original documentation this is a dict with:
            - cik: the company's Central Index Key
            - name: registered company name
            - tickers: stock ticker symbols
            - sic: Standard Industrial Classification code
            - sic_description: industry description
        On failure: {"error": ...}, plus a "traceback" entry when the tool
        raised.

    Example:
        # Search by company name
        result = advanced_search_company_detailed("Tesla")
        # Search by ticker
        result = advanced_search_company_detailed("TSLA")
        # Search by CIK
        result = advanced_search_company_detailed("0001318605")
    """
    if MCP_DIRECT_AVAILABLE:
        try:
            # Direct call into the advanced_search_company tool.
            return _advanced_search_company(company_input)
        except Exception as exc:
            import traceback
            failure = {"error": str(exc)}
            failure["traceback"] = traceback.format_exc()
            return failure

    return {"error": "MCP functions not available"}
174
+
175
+
176
def format_search_result(search_result):
    """
    Normalize an advanced_search_company result into a flat list of records.

    Each record has the shape
    {'company_name': str, 'cik': str, 'ticker': str}.

    Args:
        search_result: A result dict such as
            {'cik': '...', 'name': '...', 'tickers': [...], ...},
            or a list of such dicts (formatted recursively and flattened).

    Returns:
        list[dict]: One record per usable result dict. Error dicts (those
        containing an 'error' key) and non-dict, non-list inputs contribute
        nothing. 'ticker' is the first entry of a non-empty 'tickers' list,
        otherwise an empty string.

    Example:
        search_result = {'cik': '0001577552', 'name': 'Alibaba Group Holding Ltd', 'tickers': ['BABA'], '_source': 'company_tickers_cache'}
        formatted = format_search_result(search_result)
        # -> [{'company_name': 'Alibaba Group Holding Ltd', 'cik': '0001577552', 'ticker': 'BABA'}]
    """
    # A list is flattened element by element.
    if isinstance(search_result, list):
        flattened = []
        for entry in search_result:
            flattened += format_search_result(entry)
        return flattened

    # Anything that is not a dict, or is an error dict, yields no records.
    if not isinstance(search_result, dict) or 'error' in search_result:
        return []

    try:
        ticker_list = search_result.get('tickers', [])
        has_ticker = isinstance(ticker_list, list) and len(ticker_list) > 0
        record = {
            'company_name': search_result.get('name', ''),
            'cik': search_result.get('cik', ''),
            # Only the first ticker is kept.
            'ticker': ticker_list[0] if has_ticker else '',
        }
    except Exception:
        return []

    return [record]
233
+
234
+
235
def search_and_format(company_input):
    """
    Search for a company and format the result in one step.

    Convenience wrapper: runs advanced_search_company_detailed and passes the
    outcome through format_search_result.

    Args:
        company_input: Company name, ticker, or CIK.

    Returns:
        list[dict]: Formatted records; an empty list when the search
        reported an error.

    Example:
        result = search_and_format('BABA')
        # -> [{'company_name': 'Alibaba Group Holding Ltd', 'cik': '0001577552', 'ticker': 'BABA'}]
    """
    raw = advanced_search_company_detailed(company_input)

    # An error dict from the search short-circuits to "no results".
    search_failed = isinstance(raw, dict) and 'error' in raw
    return [] if search_failed else format_search_result(raw)
260
+
261
+
262
+ # ============================================================
263
+ # 便捷方法 - 财务数据相关
264
+ # ============================================================
265
+
266
def get_latest_financial_data_direct(cik):
    """
    Fetch a company's latest financial data by calling the MCP function directly.

    Args:
        cik: Company CIK code.

    Returns:
        Whatever _get_latest_financial_data returns, or {"error": ...} when
        the MCP functions are unavailable or the call raises.

    Example:
        result = get_latest_financial_data_direct("0000320193")
    """
    if MCP_DIRECT_AVAILABLE:
        try:
            latest = _get_latest_financial_data(cik)
        except Exception as exc:
            latest = {"error": str(exc)}
        return latest

    return {"error": "MCP functions not available"}
286
+
287
+
288
def extract_financial_metrics_direct(cik, years=5):
    """
    Extract multi-year financial metric trends by calling the MCP function directly.

    Args:
        cik: Company CIK code.
        years: Number of years to request (defaults to 5).

    Returns:
        Whatever _extract_financial_metrics returns, or {"error": ...} when
        the MCP functions are unavailable or the call raises.

    Example:
        result = extract_financial_metrics_direct("0000320193", years=5)
    """
    if MCP_DIRECT_AVAILABLE:
        try:
            trends = _extract_financial_metrics(cik, years)
        except Exception as exc:
            trends = {"error": str(exc)}
        return trends

    return {"error": "MCP functions not available"}
309
+
310
+
311
+ # ============================================================
312
+ # 高级方法 - 综合查询
313
+ # ============================================================
314
+
315
def query_company_direct(company_input, get_filings=True, get_metrics=True):
    """
    Run a combined company query using direct MCP calls.

    Steps: search for the company, then fetch its basic info, and optionally
    its filing list and financial metrics (the metrics call requests 3 years).

    Args:
        company_input: Company name or code.
        get_filings: Whether to fetch the filing list.
        get_metrics: Whether to fetch financial metrics.

    Returns:
        dict with keys:
            - timestamp: ISO timestamp of the query
            - query_input: the input as given
            - status: "success" or "error"
            - data: {"company_search", "company_info", "filings", "metrics"}
            - errors: list of error messages
        A failed search, a missing CIK, or an unexpected exception sets status
        to "error". A failure in one of the follow-up lookups is only recorded
        in "errors" and leaves the status unchanged.

    Example:
        result = query_company_direct("Apple", get_filings=True, get_metrics=True)
    """
    from datetime import datetime

    result = {
        "timestamp": datetime.now().isoformat(),
        "query_input": company_input,
        "status": "success",
        "data": {
            "company_search": None,
            "company_info": None,
            "filings": None,
            "metrics": None
        },
        "errors": []
    }

    if not MCP_DIRECT_AVAILABLE:
        result["status"] = "error"
        result["errors"].append("MCP functions not available")
        return result

    try:
        # 1. Search for the company.
        search_result = search_company_direct(company_input)

        # The search helper may wrap the tool's reply in a list, so inspect
        # the first hit: an error dict hidden inside the list would otherwise
        # be misreported as "Could not extract CIK".
        if isinstance(search_result, (list, tuple)):
            first_hit = search_result[0] if search_result else None
        else:
            first_hit = search_result

        if isinstance(first_hit, dict) and "error" in first_hit:
            result["errors"].append(f"Search error: {first_hit['error']}")
            result["status"] = "error"
            return result

        result["data"]["company_search"] = search_result

        # Pull the CIK out of the first hit.
        cik = first_hit.get("cik") if isinstance(first_hit, dict) else None
        if not cik:
            result["errors"].append("Could not extract CIK from search result")
            result["status"] = "error"
            return result

        # 2. Company info.
        company_info = get_company_info_direct(cik)
        if "error" not in company_info:
            result["data"]["company_info"] = company_info
        else:
            result["errors"].append(f"Failed to get company info: {company_info.get('error')}")

        # 3. Filing list.
        if get_filings:
            filings = get_company_filings_direct(cik)
            if "error" not in filings:
                result["data"]["filings"] = filings
            else:
                result["errors"].append(f"Failed to get filings: {filings.get('error')}")

        # 4. Financial metrics.
        if get_metrics:
            metrics = extract_financial_metrics_direct(cik, years=3)
            if "error" not in metrics:
                result["data"]["metrics"] = metrics
            else:
                result["errors"].append(f"Failed to get metrics: {metrics.get('error')}")

    except Exception as e:
        result["status"] = "error"
        result["errors"].append(f"Exception: {str(e)}")
        import traceback
        result["errors"].append(traceback.format_exc())

    return result
409
+
410
+
411
+ # ============================================================
412
+ # LLM 模型配置与初始化
413
+ # ============================================================
414
+
415
+ # 初始化 LLM 客户端
416
def _init_llm_client():
    """
    Initialize the module-level LLM client.

    Reads the token from HF_TOKEN, falling back to HUGGING_FACE_HUB_TOKEN, and
    rebinds the global `llm_client` to an InferenceClient when a token exists.
    The global is reset to None first, so it stays None on every failure path.

    Returns:
        True when the client was created, False when no token was found or
        construction raised.
    """
    global llm_client
    token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
    llm_client = None
    try:
        if not token:
            print("[FinancialAI] ⚠ Warning: HF_TOKEN not found, LLM features disabled")
            return False
        llm_client = InferenceClient(api_key=token)
        print("[FinancialAI] ✓ LLM client initialized with HF_TOKEN")
        return True
    except Exception as exc:
        print(f"[FinancialAI] ✗ Failed to initialize LLM client: {exc}")
        return False


# Module-level LLM client handle, populated once at import time.
llm_client = None
_init_llm_client()
436
+
437
+
438
def get_system_prompt():
    """
    Build the system prompt sent to the LLM.

    The prompt embeds today's date (YYYY-MM-DD, local time) and lists the
    tools the model may call together with guidance on choosing among them.

    Returns:
        str: The complete system prompt.
    """
    from datetime import datetime

    today = datetime.now().strftime("%Y-%m-%d")
    prompt = f"""You are a financial analysis expert. Today is {today}.
Your role:
- Analyze company financial data, reports, and market news
- Provide investment insights based on factual data
- Be concise, objective, and data-driven
- Always include disclaimers about market risks

⚠️ IMPORTANT: You have a maximum of 5 tool calls. Choose the MOST RELEVANT tools carefully:
- Use 'advanced_search_company' ONLY if you need to find a company's CIK
- Use 'extract_financial_metrics' for comprehensive multi-year financial analysis (RECOMMENDED for most queries)
- Use 'get_latest_financial_data' for quick recent snapshot
- Use 'get_quote' for real-time stock price
- Use 'get_company_news' for company-specific news
- Use 'get_market_news' for general market trends

Prioritize the most important tools for the user's question. Avoid redundant calls.
Output should be in English."""
    return prompt
459
+
460
+
461
def analyze_company_with_llm(company_input, analysis_type="summary"):
    """
    Ask the LLM to analyze a company.

    The data given to the model is the output of get_company_summary_direct
    (search result plus basic company info), serialized as JSON.

    Args:
        company_input: Company name or code.
        analysis_type: "investment", "risks", or anything else for the
            default summary prompt.

    Returns:
        dict with "company", "analysis_type", "analysis" (the model's text)
        and "data_used" on success; {"error": ...} when the LLM client or MCP
        functions are unavailable, the data lookup fails, or the LLM call
        raises.

    Example:
        result = analyze_company_with_llm("Apple", "investment")
    """
    if not llm_client:
        return {"error": "LLM client not available"}
    if not MCP_DIRECT_AVAILABLE:
        return {"error": "MCP functions not available"}

    try:
        # Gather the company data that backs the analysis.
        company_data = get_company_summary_direct(company_input)
        if company_data["status"] == "error":
            return {"error": f"Failed to fetch company data: {company_data['errors']}"}

        payload_json = json.dumps(company_data["data"], ensure_ascii=False, indent=2)

        # Choose the prompt template for the requested analysis type.
        if analysis_type == "risks":
            prompt = f"""
Based on the following company data, analyze the key risks:

{payload_json}

Identify:
1. Financial Risks
2. Market Risks
3. Operational Risks
4. Mitigation Strategies
5. Risk Disclaimer
"""
        elif analysis_type == "investment":
            prompt = f"""
Based on the following company financial data, provide an investment recommendation:

{payload_json}

Provide:
1. Investment Recommendation (Buy/Hold/Sell)
2. Key Strengths and Weaknesses
3. Price Target Range
4. Risk Assessment
5. Risk Disclaimer
"""
        else:  # summary
            prompt = f"""
Provide a financial summary of the following company:

{payload_json}

Include:
1. Company Overview
2. Financial Health
3. Recent Performance
4. Investment Outlook
"""

        conversation = [
            {"role": "system", "content": get_system_prompt()},
            {"role": "user", "content": prompt}
        ]

        # Single non-streaming chat completion.
        completion = llm_client.chat.completions.create(
            model="Qwen/Qwen2.5-72B-Instruct",
            messages=conversation,
            max_tokens=1500,
            temperature=0.7,
            top_p=0.95,
            stream=False
        )
        analysis_text = completion.choices[0].message.content

        return {
            "company": company_input,
            "analysis_type": analysis_type,
            "analysis": analysis_text,
            "data_used": company_data["data"]
        }

    except Exception as exc:
        return {"error": f"LLM analysis failed: {str(exc)}"}
551
+
552
+
553
+ # ============================================================
554
+ # 便捷方法 - 获取单一时期财务数据
555
+ # ============================================================
556
+
557
def get_financial_data_direct(cik, period):
    """
    Fetch financial data for one period by calling the MCP function directly.

    Args:
        cik: Company CIK code.
        period: Reporting period (e.g., "2024", "2024Q3").

    Returns:
        Whatever _get_financial_data returns, or {"error": ...} when the MCP
        functions are unavailable or the call raises.

    Example:
        result = get_financial_data_direct("0000320193", "2024")
    """
    if MCP_DIRECT_AVAILABLE:
        try:
            period_data = _get_financial_data(cik, period)
        except Exception as exc:
            period_data = {"error": str(exc)}
        return period_data

    return {"error": "MCP functions not available"}
578
+
579
+
580
+ # ============================================================
581
+ # 便捷方法 - 获取文件列表
582
+ # ============================================================
583
+
584
def get_company_filings_with_form_direct(cik, form_types=None):
    """
    Fetch a company's SEC filings for given form types via a direct MCP call.

    Args:
        cik: Company CIK code.
        form_types: List of form types (e.g., ["10-K", "10-Q"]); passed
            through to _get_company_filings unchanged.

    Returns:
        Whatever _get_company_filings returns, or {"error": ...} when the MCP
        functions are unavailable or the call raises.

    Example:
        result = get_company_filings_with_form_direct("0000320193", ["10-K"])
    """
    if MCP_DIRECT_AVAILABLE:
        try:
            filings = _get_company_filings(cik, form_types)
        except Exception as exc:
            filings = {"error": str(exc)}
        return filings

    return {"error": "MCP functions not available"}
605
+
606
+
607
+ # ============================================================
608
+ # 便捷方法 - 轻量级查询
609
+ # ============================================================
610
+
611
def get_company_summary_direct(company_input):
    """
    Fetch a lightweight company summary (search result and basic info only).

    Args:
        company_input: Company name or code.

    Returns:
        dict with keys:
            - timestamp: ISO timestamp of the query
            - query_input: the input as given
            - status: "success" or "error"
            - data: {"company_search", "company_info"}
            - errors: list of error messages
        A failed search, a missing CIK, or an unexpected exception sets status
        to "error". A failed company-info lookup is only recorded in "errors".

    Example:
        result = get_company_summary_direct("Apple")
    """
    from datetime import datetime

    result = {
        "timestamp": datetime.now().isoformat(),
        "query_input": company_input,
        "status": "success",
        "data": {
            "company_search": None,
            "company_info": None
        },
        "errors": []
    }

    if not MCP_DIRECT_AVAILABLE:
        result["status"] = "error"
        result["errors"].append("MCP functions not available")
        return result

    try:
        # 1. Search for the company.
        search_result = search_company_direct(company_input)

        # The search helper may wrap the tool's reply in a list, so inspect
        # the first hit: an error dict hidden inside the list would otherwise
        # be misreported as "Could not extract CIK".
        if isinstance(search_result, (list, tuple)):
            first_hit = search_result[0] if search_result else None
        else:
            first_hit = search_result

        if isinstance(first_hit, dict) and "error" in first_hit:
            result["errors"].append(f"Search error: {first_hit['error']}")
            result["status"] = "error"
            return result

        result["data"]["company_search"] = search_result

        # Pull the CIK out of the first hit.
        cik = first_hit.get("cik") if isinstance(first_hit, dict) else None
        if not cik:
            result["errors"].append("Could not extract CIK from search result")
            result["status"] = "error"
            return result

        # 2. Company info.
        company_info = get_company_info_direct(cik)
        if "error" not in company_info:
            result["data"]["company_info"] = company_info
        else:
            result["errors"].append(f"Failed to get company info: {company_info.get('error')}")

    except Exception as e:
        result["status"] = "error"
        result["errors"].append(f"Exception: {str(e)}")
        import traceback
        result["errors"].append(traceback.format_exc())

    return result
683
+
684
+
685
def get_financial_metrics_only_direct(company_input, years=5):
    """
    Fetch a company's financial metric trends only (no filing list).

    Args:
        company_input: Company name or code.
        years: Number of years to request (defaults to 5).

    Returns:
        dict with keys:
            - timestamp: ISO timestamp of the query
            - query_input: the input as given
            - years: the requested number of years
            - status: "success" or "error"
            - data: the metrics payload, or None on failure
            - errors: list of error messages

    Example:
        result = get_financial_metrics_only_direct("Apple", years=5)
    """
    from datetime import datetime

    result = {
        "timestamp": datetime.now().isoformat(),
        "query_input": company_input,
        "years": years,
        "status": "success",
        "data": None,
        "errors": []
    }

    if not MCP_DIRECT_AVAILABLE:
        result["status"] = "error"
        result["errors"].append("MCP functions not available")
        return result

    try:
        # 1. Search for the company.
        search_result = search_company_direct(company_input)

        # The search helper may wrap the tool's reply in a list, so inspect
        # the first hit: an error dict hidden inside the list would otherwise
        # be misreported as "Could not extract CIK".
        if isinstance(search_result, (list, tuple)):
            first_hit = search_result[0] if search_result else None
        else:
            first_hit = search_result

        if isinstance(first_hit, dict) and "error" in first_hit:
            result["errors"].append(f"Search error: {first_hit['error']}")
            result["status"] = "error"
            return result

        # Pull the CIK out of the first hit.
        cik = first_hit.get("cik") if isinstance(first_hit, dict) else None
        if not cik:
            result["errors"].append("Could not extract CIK from search result")
            result["status"] = "error"
            return result

        # 2. Financial metrics.
        metrics = extract_financial_metrics_direct(cik, years=years)
        if "error" in metrics:
            result["errors"].append(f"Failed to get metrics: {metrics['error']}")
            result["status"] = "error"
        else:
            result["data"] = metrics

    except Exception as e:
        result["status"] = "error"
        result["errors"].append(f"Exception: {str(e)}")
        import traceback
        result["errors"].append(traceback.format_exc())

    return result
755
+
756
+
757
+ # ============================================================
758
+ # 测试函数
759
+ # ============================================================
760
+
761
if __name__ == "__main__":
    # Manual smoke test: exercises each helper once and prints the outcome.
    banner = "=" * 60
    print("\n" + banner)
    print("Financial AI Assistant - Direct Method Test")
    print(banner)

    # Test 1: company search
    print("\n1. 搜索公司 (Apple)...")
    result = search_company_direct("Apple")
    print(f" 结果: {result}")

    # Test 2: company summary
    print("\n2. 获取公司摘要信息 (Tesla)...")
    summary = get_company_summary_direct("Tesla")
    for label, key in ((" 状态: ", "status"), (" 数据: ", "data"), (" 错误: ", "errors")):
        print(f"{label}{summary[key]}")

    # Test 3: financial metrics
    print("\n3. 获取财务指标 (Microsoft)...")
    metrics = get_financial_metrics_only_direct("Microsoft", years=3)
    print(f" 状态: {metrics['status']}")
    if metrics['status'] != 'success':
        print(f" 错误: {metrics['errors']}")
    else:
        print(f" 指标数据: {metrics['data']}")

    # Test 4: full query
    print("\n4. 获取 Amazon 完整信息...")
    full_query = query_company_direct("Amazon", get_filings=True, get_metrics=True)
    print(f" 状态: {full_query['status']}")
    print(f" 错误: {full_query['errors']}")

    # Tests 5-7: LLM analyses (summary, investment advice, risk assessment)
    llm_cases = (
        ("\n5. LLM 分析 - 公司摘要(Google)...", "Google", "summary"),
        ("\n6. LLM 分析 - 投资建议(NVIDIA)...", "NVIDIA", "investment"),
        ("\n7. LLM 分析 - 风险评估(Meta)...", "Meta", "risks"),
    )
    for heading, company, kind in llm_cases:
        print(heading)
        if not llm_client:
            print(" LLM 客户端不可用")
            continue
        llm_result = analyze_company_with_llm(company, kind)
        if "error" in llm_result:
            print(f" 错误: {llm_result['error']}")
        else:
            print(f" 分析结果: {llm_result['analysis'][:200]}...")

    print("\n" + banner)
827
+
828
+
829
+ # ============================================================
830
+ # 完整对话引擎 - chatbot_response
831
+ # ============================================================
832
+
833
+ # Token 限制配置
834
MAX_TOTAL_TOKENS = 6000
MAX_TOOL_RESULT_CHARS = 1500   # size threshold applied to serialized tool results
MAX_HISTORY_CHARS = 500        # longer assistant replies in history are truncated
MAX_HISTORY_TURNS = 2          # only this many most recent history items are replayed
MAX_TOOL_ITERATIONS = 5        # upper bound on LLM/tool-call rounds per request
MAX_OUTPUT_TOKENS = 2000       # max_tokens passed to the chat completion call
840
+
841
+ # MCP 工具配置 - 包含财务数据和市场新闻工具
842
MCP_TOOLS = [
    # Financial data tools (EasyReportDataMCP)
    {
        "type": "function",
        "function": {
            "name": "advanced_search_company",
            "description": "Search US companies by name, ticker, or CIK. Returns company information including CIK, name, tickers, and industry classification.",
            "parameters": {
                "type": "object",
                "properties": {
                    "company_input": {
                        "type": "string",
                        "description": "Company name (e.g., 'Tesla'), ticker symbol (e.g., 'TSLA'), or CIK code (e.g., '0001318605')",
                    },
                },
                "required": ["company_input"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "get_latest_financial_data",
            "description": "Get the most recent financial data for a company including revenue, net income, EPS, operating expenses, and cash flow.",
            "parameters": {
                "type": "object",
                "properties": {
                    "cik": {
                        "type": "string",
                        "description": "Company CIK code (10-digit format, e.g., '0001318605')",
                    },
                },
                "required": ["cik"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "extract_financial_metrics",
            "description": "Extract multi-year financial metrics trends showing historical performance over specified years.",
            "parameters": {
                "type": "object",
                "properties": {
                    "cik": {
                        "type": "string",
                        "description": "Company CIK code (10-digit format)",
                    },
                    "years": {
                        "type": "integer",
                        "description": "Number of years of data to retrieve (e.g., 3 or 5)",
                        "default": 3,
                    },
                },
                "required": ["cik", "years"],
            },
        },
    },

    # Market and news tools (MarketandStockMCP)
    {
        "type": "function",
        "function": {
            "name": "get_quote",
            "description": "Get real-time stock quote data including current price, daily change, high/low, and previous close. Use when users ask about current stock prices or market performance.",
            "parameters": {
                "type": "object",
                "properties": {
                    "symbol": {
                        "type": "string",
                        "description": "Stock ticker symbol (e.g., 'AAPL', 'TSLA', 'MSFT')",
                    },
                },
                "required": ["symbol"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "get_market_news",
            "description": "Get latest market news by category. Use when users ask about general market trends, forex, crypto, or M&A news.",
            "parameters": {
                "type": "object",
                "properties": {
                    "category": {
                        "type": "string",
                        "enum": ["general", "forex", "crypto", "merger"],
                        "description": "News category: general (stocks/economy), forex (currency), crypto (cryptocurrency), merger (M&A)",
                        "default": "general",
                    },
                    "min_id": {
                        "type": "integer",
                        "description": "Minimum news ID for pagination (default: 0)",
                        "default": 0,
                    },
                },
                "required": ["category"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "get_company_news",
            "description": "Get company-specific news and announcements. Only available for North American companies. Use when users ask about specific company news.",
            "parameters": {
                "type": "object",
                "properties": {
                    "symbol": {
                        "type": "string",
                        "description": "Company stock ticker symbol (e.g., 'AAPL', 'TSLA')",
                    },
                    "from_date": {
                        "type": "string",
                        "description": "Start date in YYYY-MM-DD format (optional, defaults to 7 days ago)",
                    },
                    "to_date": {
                        "type": "string",
                        "description": "End date in YYYY-MM-DD format (optional, defaults to today)",
                    },
                },
                "required": ["symbol"],
            },
        },
    },
]
853
+
854
+
855
def truncate_text(text, max_chars, suffix="...[truncated]"):
    """
    Cut the string form of `text` down to `max_chars` characters.

    Args:
        text: Value to truncate; converted with str() first.
        max_chars: Number of leading characters to keep.
        suffix: Marker appended after the kept characters when truncation
            happens (so the result is longer than max_chars by len(suffix)).

    Returns:
        str: The unchanged string when it already fits, otherwise the first
        `max_chars` characters followed by `suffix`.
    """
    as_str = str(text)
    fits = len(as_str) <= max_chars
    return as_str if fits else as_str[:max_chars] + suffix
861
+
862
+
863
def call_mcp_tool(tool_name, arguments):
    """
    Dispatch a tool call to the matching Python function (no HTTP involved).

    Args:
        tool_name: One of "advanced_search_company",
            "get_latest_financial_data", "extract_financial_metrics",
            "get_quote", "get_market_news", "get_company_news".
        arguments: Mapping of tool arguments; missing keys fall back to the
            defaults used below.

    Returns:
        The called function's return value; {"error": "Unknown tool: ..."} for
        an unrecognized name; or {"error": ..., "traceback": ...} (traceback
        capped at 500 characters) when anything raises.
    """
    try:
        # Financial data tools: call the imported MCP functions.
        if tool_name == "advanced_search_company":
            query = arguments.get("company_input", "")
            return _advanced_search_company(query)

        if tool_name == "get_latest_financial_data":
            company_cik = arguments.get("cik", "")
            return _get_latest_financial_data(company_cik)

        if tool_name == "extract_financial_metrics":
            company_cik = arguments.get("cik", "")
            num_years = arguments.get("years", 3)
            return _extract_financial_metrics(company_cik, num_years)

        # Market and news tools: imported lazily, only when requested.
        if tool_name == "get_quote":
            from MarketandStockMCP.news_quote_mcp import get_quote
            ticker = arguments.get("symbol", "")
            return get_quote(ticker)

        if tool_name == "get_market_news":
            from MarketandStockMCP.news_quote_mcp import get_market_news
            news_category = arguments.get("category", "general")
            first_id = arguments.get("min_id", 0)
            return get_market_news(news_category, first_id)

        if tool_name == "get_company_news":
            from MarketandStockMCP.news_quote_mcp import get_company_news
            ticker = arguments.get("symbol", "")
            start = arguments.get("from_date")
            end = arguments.get("to_date")
            return get_company_news(ticker, start, end)

        return {"error": f"Unknown tool: {tool_name}"}

    except Exception as exc:
        import traceback
        return {
            "error": f"{str(exc)}",
            "traceback": traceback.format_exc()[:500]
        }
908
+
909
+
910
def _parse_tool_call_arguments(raw_arguments):
    """Return the arguments of a tool call as a dict.

    Accepts either a JSON string or an already-decoded dict. Anything that
    cannot be decoded into a JSON object (invalid JSON, ``None``, a JSON
    list or scalar) yields an empty dict so the tool runs with its defaults.
    """
    if isinstance(raw_arguments, dict):
        return raw_arguments
    try:
        parsed = json.loads(raw_arguments)
    except (ValueError, TypeError):
        # ValueError covers json.JSONDecodeError; TypeError covers None and
        # other non-string payloads.
        return {}
    return parsed if isinstance(parsed, dict) else {}


def _serialize_tool_result_for_llm(tool_result):
    """Serialize a tool result to a JSON string sized for the LLM context.

    Error results are reduced to their ``error`` field. Results whose JSON
    form exceeds ``MAX_TOOL_RESULT_CHARS`` are shrunk and tagged with
    ``"_truncated": True``.
    """
    if isinstance(tool_result, dict) and "error" in tool_result:
        return json.dumps(
            {"error": tool_result.get("error", "Unknown error")},
            ensure_ascii=False,
            default=str,
        )

    # default=str: a value that is not JSON-native is sent as its string
    # form instead of aborting the whole conversation turn.
    result_str = json.dumps(tool_result, ensure_ascii=False, default=str)
    if len(result_str) <= MAX_TOOL_RESULT_CHARS:
        return result_str

    if isinstance(tool_result, dict) and "text" in tool_result:
        shrunk = {
            "text": truncate_text(tool_result["text"], MAX_TOOL_RESULT_CHARS - 50),
            "_truncated": True,
        }
    elif isinstance(tool_result, dict):
        # Keep at most the first 8 fields, each clipped to 300 characters,
        # and stop once the running size passes the budget.
        clipped = {}
        char_count = 0
        for key, value in list(tool_result.items())[:8]:
            value_str = str(value)[:300]
            clipped[key] = value_str
            char_count += len(str(key)) + len(value_str)
            if char_count > MAX_TOOL_RESULT_CHARS:
                break
        shrunk = {**clipped, "_truncated": True}
    else:
        shrunk = {
            "preview": truncate_text(result_str, MAX_TOOL_RESULT_CHARS),
            "_truncated": True,
        }
    return json.dumps(shrunk, ensure_ascii=False, default=str)


def _render_tool_calls_html(tool_calls_log):
    """Build the collapsible "Tools Used" HTML panel shown above the answer.

    Tool names and result previews are HTML-escaped because they originate
    from the model and from external data sources.
    """
    import html

    parts = [
        "<div style='margin-bottom: 15px;'>\n"
        "<div style='background: #f0f0f0; padding: 8px 12px; border-radius: 6px; font-weight: 600; color: #333;'>\n"
        f"🛠️ Tools Used ({len(tool_calls_log)}/{MAX_TOOL_ITERATIONS} calls)\n"
        "</div>\n"
    ]

    for position, entry in enumerate(tool_calls_log, start=1):
        result_json = json.dumps(
            entry.get("result", {}), ensure_ascii=False, indent=2, default=str
        )
        # Truncate first, then escape, so an entity is never cut in half.
        result_preview = html.escape(result_json[:1500])
        if len(result_json) > 1500:
            result_preview += "..."
        error_indicator = " ❌ Error" if entry.get("error") else ""
        tool_label = html.escape(str(entry["name"]))

        parts.append(
            "<details style='margin: 8px 0; border: 1px solid #ddd; border-radius: 6px; overflow: hidden;'>\n"
            "<summary style='background: #fff; padding: 10px; cursor: pointer; user-select: none; list-style: none;'>\n"
            f"<strong style='color: #2c5aa0;'>📋 {position}. {tool_label}{error_indicator}</strong>\n"
            "</summary>\n"
            "<div style='background: #f9f9f9; padding: 12px;'>\n"
            f"<pre style='background: #fff; padding: 10px; overflow-x: auto; font-size: 0.85em;'>{result_preview}</pre>\n"
            "</div>\n"
            "</details>\n"
        )

    parts.append("</div>\n---\n")
    return "".join(parts)


def chatbot_response(message, history=None):
    """
    Main entry point of the AI assistant (full conversation engine).

    Supports multi-turn conversation, dynamic tool calls, and streamed output.

    Args:
        message: The user's message.
        history: Conversation history as [(user_msg, assistant_msg), ...].
            Only the most recent MAX_HISTORY_TURNS entries are used; entries
            that are not 2-item lists/tuples are skipped, and assistant
            messages longer than MAX_HISTORY_CHARS are truncated.

    Returns:
        A generator that keeps yielding response text. Every yielded value
        is the full text to display so far (tool panel plus answer), not a
        delta. Errors are yielded as text starting with "❌ Error" rather
        than raised.

    Example:
        for response in chatbot_response("What's Apple's revenue?", []):
            print(response)
    """
    if not llm_client:
        yield "❌ Error: LLM client not available"
        return

    if not MCP_DIRECT_AVAILABLE:
        yield "❌ Error: MCP functions not available"
        return

    model_name = "Qwen/Qwen2.5-72B-Instruct"

    try:
        messages = [{"role": "system", "content": get_system_prompt()}]

        # Replay recent history only, to keep the context length bounded.
        if history:
            for item in history[-MAX_HISTORY_TURNS:]:
                if isinstance(item, (list, tuple)) and len(item) == 2:
                    messages.append({"role": "user", "content": item[0]})
                    assistant_msg = str(item[1])
                    if len(assistant_msg) > MAX_HISTORY_CHARS:
                        assistant_msg = truncate_text(assistant_msg, MAX_HISTORY_CHARS)
                    messages.append({"role": "assistant", "content": assistant_msg})

        messages.append({"role": "user", "content": message})

        tool_calls_log = []
        final_response_content = None

        # LLM call loop: each pass either runs the requested tools and asks
        # again, or captures the final answer and stops.
        for _ in range(MAX_TOOL_ITERATIONS):
            response = llm_client.chat.completions.create(
                model=model_name,
                messages=messages,
                tools=MCP_TOOLS,  # type: ignore
                max_tokens=MAX_OUTPUT_TOKENS,
                temperature=0.7,
                tool_choice="auto",
                stream=False
            )

            choice = response.choices[0]

            if not choice.message.tool_calls:
                final_response_content = choice.message.content
                break

            # The assistant turn that requested the tools must precede the
            # tool results in the transcript.
            messages.append(choice.message)

            for tool_call in choice.message.tool_calls:
                tool_name = tool_call.function.name
                tool_args = _parse_tool_call_arguments(tool_call.function.arguments)

                tool_result = call_mcp_tool(tool_name, tool_args)
                result_for_llm = _serialize_tool_result_for_llm(tool_result)

                log_entry = {"name": tool_name, "arguments": tool_args, "result": tool_result}
                if isinstance(tool_result, dict) and "error" in tool_result:
                    log_entry["error"] = True
                tool_calls_log.append(log_entry)

                messages.append({
                    "role": "tool",
                    "name": tool_name,
                    "content": result_for_llm,
                    "tool_call_id": tool_call.id
                })

        response_prefix = ""

        if tool_calls_log:
            # Show the tool panel right away, before the answer is ready.
            response_prefix = _render_tool_calls_html(tool_calls_log)
            yield response_prefix

        if final_response_content:
            yield response_prefix + final_response_content
        else:
            # No answer text yet (tool budget exhausted or empty content):
            # request one without tools, streaming if possible.
            try:
                stream = llm_client.chat.completions.create(
                    model=model_name,
                    messages=messages,
                    tools=None,
                    max_tokens=MAX_OUTPUT_TOKENS,
                    temperature=0.7,
                    stream=True
                )

                accumulated_text = ""
                for chunk in stream:
                    if chunk.choices and len(chunk.choices) > 0 and chunk.choices[0].delta.content:
                        accumulated_text += chunk.choices[0].delta.content
                        yield response_prefix + accumulated_text
            except Exception:
                # Streaming failed; fall back to one blocking request. A
                # failure here is reported by the outer handler.
                final_resp = llm_client.chat.completions.create(
                    model=model_name,
                    messages=messages,
                    tools=None,
                    max_tokens=MAX_OUTPUT_TOKENS,
                    temperature=0.7,
                    stream=False
                )
                yield response_prefix + (final_resp.choices[0].message.content or "")

    except Exception as e:
        import traceback
        error_detail = str(e)
        if "500" in error_detail:
            yield f"❌ Error: 模型服务器错误\n\n{error_detail[:200]}"
        else:
            yield f"❌ Error: {error_detail}\n\n{traceback.format_exc()[:500]}"