Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import requests | |
| import json | |
| import os | |
| import time | |
| from requests.adapters import HTTPAdapter | |
| from urllib3.util.retry import Retry | |
| from huggingface_hub import InferenceClient | |
# Hugging Face Space that hosts the SEC EDGAR MCP server, and its public base URL.
MCP_SPACE = "JC321/EasyReportDateMCP"
MCP_URL = "https://jc321-easyreportdatemcp.hf.space"
# Path of the JSON-RPC tool-call endpoint, appended to MCP_URL.
MCP_ENDPOINT = "/mcp"

# Identifies this client (with a contact address) on every MCP request.
_USER_AGENT = "SEC-Query-Assistant/1.0 (jtyxabc@gmail.com)"
HEADERS = {
    "Content-Type": "application/json",
    "User-Agent": _USER_AGENT,
}
# Shared HTTP plumbing: a requests session with automatic retries.
def create_session_with_retry(total=3, backoff_factor=1, status_forcelist=(500, 502, 503, 504)):
    """Create a ``requests.Session`` that retries transient MCP server failures.

    Every request this module sends to the MCP server is a JSON-RPC ``POST``.
    urllib3's ``Retry`` only retries methods listed in ``allowed_methods``, and
    its default set leaves ``POST`` out -- without listing it explicitly the
    status-code retries configured here would never fire for these calls.

    Args:
        total: maximum number of retries per request (default 3).
        backoff_factor: urllib3 backoff factor. urllib3 skips the delay before
            the first retry and doubles it afterwards, so the default of 1
            sleeps roughly 0s, 2s, 4s (a server ``Retry-After`` header on a
            503 can override this).
        status_forcelist: HTTP status codes that trigger a retry.

    Returns:
        A session with the retrying adapter mounted for ``http://`` and
        ``https://``.
    """
    retry = Retry(
        total=total,
        backoff_factor=backoff_factor,
        status_forcelist=status_forcelist,
        # POST has to be allowed explicitly. The tools called through this
        # session are lookups (search / get / extract), so repeating a call is
        # assumed to be safe.
        allowed_methods=Retry.DEFAULT_ALLOWED_METHODS | {"POST"},
        # Don't retry read timeouts: callers already use long per-request
        # timeouts (up to 180s) and retrying would multiply the total wait.
        read=False,
        # Once retries are exhausted, return the final 5xx response instead of
        # raising RetryError, so callers' `status_code != 200` handling still
        # sees the server's error body.
        raise_on_status=False,
    )
    adapter = HTTPAdapter(max_retries=retry)
    session = requests.Session()
    for scheme in ("http://", "https://"):
        session.mount(scheme, adapter)
    return session
# One module-wide HTTP session, shared by every MCP request in this file.
session = create_session_with_retry()

# Hugging Face Inference client used by the chatbot. HF_TOKEN is optional
# (e.g. for private models); None is passed through when it is unset.
HF_TOKEN = os.environ.get("HF_TOKEN")
try:
    client = InferenceClient(token=HF_TOKEN)
except Exception as init_error:
    # Degrade gracefully: the chatbot checks for `client is None` and falls back.
    print(f"Warning: Failed to initialize Hugging Face client: {init_error}")
    client = None
def _function_tool(name, description, properties, required):
    """Build one OpenAI-style ``function`` tool schema with object parameters."""
    return {
        "type": "function",
        "function": {
            "name": name,
            "description": description,
            "parameters": {
                "type": "object",
                "properties": properties,
                "required": required,
            },
        },
    }


# Tool schemas advertised to the LLM; each name matches an MCP server tool.
MCP_TOOLS = [
    _function_tool(
        "advanced_search_company",
        "Search for a US listed company by name or stock ticker symbol to get basic company information including CIK, name, and ticker",
        {
            "company_input": {
                "type": "string",
                "description": "Company name or stock ticker symbol (e.g., 'Apple', 'AAPL', 'Microsoft', 'TSLA')",
            },
        },
        ["company_input"],
    ),
    _function_tool(
        "get_latest_financial_data",
        "Get the latest financial data for a company using its CIK number. Returns revenue, net income, EPS, operating expenses, and cash flow for the most recent fiscal period",
        {
            "cik": {
                "type": "string",
                "description": "10-digit CIK number of the company (must be obtained from advanced_search_company first)",
            },
        },
        ["cik"],
    ),
    _function_tool(
        "extract_financial_metrics",
        "Extract financial metrics trends over multiple years for a company. Returns historical data including revenue, net income, EPS, operating expenses, and cash flow",
        {
            "cik": {
                "type": "string",
                "description": "10-digit CIK number of the company",
            },
            "years": {
                "type": "integer",
                "description": "Number of years to retrieve (typically 3 or 5)",
                "enum": [3, 5],
            },
        },
        ["cik", "years"],
    ),
]
# Number formatting for the Markdown reports.
def format_value(value, value_type="money", threshold=0.01):
    """Format a metric for display, or return "N/A" when it is missing or negligible.

    Args:
        value: the number to format, or None. Callers pass money amounts already
            scaled to billions and EPS in dollars.
        value_type: "money" -> "$1.23B", "eps" -> "$1.23", anything else
            (e.g. "number") -> "1.23".
        threshold: values whose absolute size is below this render as "N/A".
            The default 0.01 hides money amounts under $10M (0.01B), which the
            report treats as meaningless/absent data.

    Returns:
        The formatted string. Negative currency amounts carry the sign before
        the dollar symbol ("-$1.23B", not "$-1.23B").
    """
    if value is None or abs(value) < threshold:
        return "N/A"
    sign = "-" if value < 0 else ""
    magnitude = abs(value)
    if value_type == "money":
        return f"{sign}${magnitude:.2f}B"
    if value_type == "eps":
        return f"{sign}${magnitude:.2f}"
    return f"{value:.2f}"
# NOTE: `call_mcp_tool` is defined exactly once, further down in this module
# next to the chatbot helpers. A second `def` here would be silently rebound by
# that later definition at import time (it could never run), so it is not
# repeated.
def normalize_cik(cik):
    """Return ``cik`` as a zero-padded 10-digit CIK string, or None.

    The value is stringified and every non-digit character (dashes, spaces,
    letters, ...) is dropped before left-padding with zeros. Falsy input, or
    input with no digits at all, yields None. Inputs longer than 10 digits are
    returned as-is (neither padded nor truncated).
    """
    if not cik:
        return None
    digits = "".join(filter(str.isdigit, str(cik)))
    if not digits:
        return None
    return digits.zfill(10)
def _decode_mcp_text(text):
    """Decode a text content item as JSON, falling back to the raw value."""
    try:
        return json.loads(text)
    except (ValueError, TypeError):
        # Not JSON (plain message) or not a string at all: hand it back as-is.
        return text


def _unwrap_mcp_content(payload):
    """Decode the first item of an MCP ``content`` list found in ``payload``.

    Returns {} when there is no usable first item. When the payload carries
    MCP's ``isError`` flag, the message is surfaced as ``{"error": ...}`` so
    callers' ``.get("error")`` checks catch tool failures.
    """
    content = payload.get("content")
    if not isinstance(content, list) or not content or not isinstance(content[0], dict):
        return {}
    raw_text = content[0].get("text", "{}")
    decoded = _decode_mcp_text(raw_text)
    if payload.get("isError") and not (isinstance(decoded, dict) and "error" in decoded):
        return {"error": raw_text or "MCP tool reported an error"}
    return decoded


def parse_mcp_response(response_data):
    """Extract the tool payload from an MCP ``tools/call`` response.

    Supported shapes:
      1. ``{"result": {"content": [{"type": "text", "text": "{...}"}]}}``
      2. ``{"content": [{"type": "text", "text": "{...}"}]}``
      3. anything else -- returned unchanged (including non-dict input and
         JSON-RPC ``{"error": {...}}`` envelopes).

    For shapes 1 and 2 the first item's ``text`` is JSON-decoded when possible
    (otherwise the raw text is returned), an empty/malformed ``content`` list
    yields {}, and an ``isError`` result becomes ``{"error": <message>}``.
    """
    if not isinstance(response_data, dict):
        return response_data
    result = response_data.get("result")
    # Shape 1: only when "result" really is a mapping with a content list key.
    if isinstance(result, dict) and "content" in result:
        return _unwrap_mcp_content(result)
    # Shape 2: content at the top level.
    if "content" in response_data:
        return _unwrap_mcp_content(response_data)
    # Shape 3: already-plain data.
    return response_data
# High-level tool definition wrapping query_financial_data.
def create_mcp_tools():
    """Return the schema list for the single ``query_financial_data`` tool.

    The ``query_type`` enum lists every query type that query_financial_data
    and the Direct Query UI support, including "Company Filings". A new list
    is built on every call, so callers may mutate the result freely.
    """
    query_types = ["Latest Financial Data", "3-Year Trends", "5-Year Trends", "Company Filings"]
    return [
        {
            "name": "query_financial_data",
            "description": "Query SEC financial data for US listed companies",
            "parameters": {
                "type": "object",
                "properties": {
                    "company_name": {
                        "type": "string",
                        "description": "Company name or stock symbol (e.g., Apple, NVIDIA, AAPL)"
                    },
                    "query_type": {
                        "type": "string",
                        "enum": query_types,
                        "description": "Type of financial query"
                    }
                },
                "required": ["company_name", "query_type"]
            }
        }
    ]
# Dispatcher for the high-level tool defined in create_mcp_tools().
def execute_tool(tool_name, **kwargs):
    """Run the named high-level tool with keyword arguments.

    Only ``query_financial_data`` is known; it receives the ``company_name``
    and ``query_type`` kwargs (None when absent). Any other name yields an
    "Unknown tool: <name>" message instead of raising.
    """
    if tool_name != "query_financial_data":
        return f"Unknown tool: {tool_name}"
    company = kwargs.get("company_name")
    kind = kwargs.get("query_type")
    return query_financial_data(company, kind)
# Markdown links for filing forms.
def create_source_link(source_form, source_url=None):
    """Render a filing's form name as a Markdown link when the backend gave a URL.

    A missing/empty form or the literal "N/A" is returned unchanged. With a
    usable form, a missing or "N/A" URL yields the plain form text; otherwise
    the result is ``[form](url)``.
    """
    has_form = bool(source_form) and source_form != 'N/A'
    has_url = bool(source_url) and source_url != 'N/A'
    if has_form and has_url:
        return f"[{source_form}]({source_url})"
    return source_form
def _fetch_mcp_tool_data(tool_name, arguments, timeout, parse_error_label=""):
    """POST a JSON-RPC ``tools/call`` to the MCP server and decode the reply.

    Returns ``(data, None)`` on success or ``(None, message)`` where message is
    a user-facing Markdown error. Network exceptions raised by ``session.post``
    propagate to the caller. ``parse_error_label`` is inserted before the raw
    body in the JSON-parse error message.
    """
    response = session.post(
        f"{MCP_URL}{MCP_ENDPOINT}",
        json={
            "jsonrpc": "2.0",
            "method": "tools/call",
            "params": {"name": tool_name, "arguments": arguments},
            "id": 1,
        },
        headers=HEADERS,
        timeout=timeout,
    )
    if response.status_code != 200:
        return None, f"❌ Server Error: HTTP {response.status_code}\n\n{response.text[:500]}"
    try:
        data = parse_mcp_response(response.json())
    except (ValueError, KeyError, json.JSONDecodeError) as e:
        return None, f"❌ JSON Parse Error: {str(e)}\n\n{parse_error_label}{response.text[:500]}"
    if isinstance(data, dict) and data.get("error"):
        return None, f"❌ {data['error']}"
    return data, None


def _to_billions(value):
    """Scale a raw dollar amount to billions; missing/zero values become 0."""
    return (value or 0) / 1e9


def _period_sort_key(row):
    """Sort key that orders periods newest-first when used with ``reverse=True``.

    Target order: FY2024 -> 2024Q3 -> 2024Q2 -> 2024Q1 -> FY2023. The year is
    the first 4-digit run, so both "2024" and "FY2024" style labels work.
    Within a year the annual row sorts first, then quarters from highest to
    lowest; a "Q" without a digit counts as quarter 0.
    """
    import re

    period = str(row.get('period') or '')
    year_match = re.search(r"\d{4}", period)
    year = int(year_match.group()) if year_match else 0
    if 'Q' in period:
        quarter_match = re.search(r"Q(\d)", period)
        quarter = int(quarter_match.group(1)) if quarter_match else 0
        return (year, 0, quarter)
    return (year, 1, 0)


def _format_period_label(period):
    """Prefix annual periods with "FY" (once); quarters and "N/A" pass through."""
    if period == 'N/A' or 'Q' in period or period.startswith('FY'):
        return period
    return f"FY{period}"


def _dedupe_by_period(rows):
    """Keep the first row per (period, source_form) pair; skip non-dict rows."""
    seen = set()
    unique = []
    for row in rows:
        if not isinstance(row, dict):
            continue
        key = (row.get('period', 'N/A'), row.get('source_form', 'N/A'))
        if key in seen:
            continue
        seen.add(key)
        unique.append(row)
    return unique


def _render_latest_financials(cik):
    """Fetch and render the latest-period metrics section for ``cik``."""
    data, error = _fetch_mcp_tool_data("get_latest_financial_data", {"cik": cik}, timeout=60)
    if error:
        return error
    if not isinstance(data, dict):
        return f"❌ Unexpected response from get_latest_financial_data: {data}"
    # source_url comes from the MCP backend; create_source_link handles a missing one.
    source_link = create_source_link(data.get('source_form', 'N/A'), data.get('source_url'))
    return (
        f"## Fiscal Year {data.get('period', 'N/A')}\n\n"
        f"- **Total Revenue**: {format_value(_to_billions(data.get('total_revenue')))}\n"
        f"- **Net Income**: {format_value(_to_billions(data.get('net_income')))}\n"
        f"- **Earnings Per Share**: {format_value(data.get('earnings_per_share') or 0, 'eps')}\n"
        f"- **Operating Expenses**: {format_value(_to_billions(data.get('operating_expenses')))}\n"
        f"- **Operating Cash Flow**: {format_value(_to_billions(data.get('operating_cash_flow')))}\n"
        f"- **Source Form**: {source_link}\n"
    )


def _render_trend_table(cik, years):
    """Fetch ``years`` of metrics for ``cik`` and render a newest-first Markdown table."""
    metrics, error = _fetch_mcp_tool_data(
        "extract_financial_metrics",
        {"cik": cik, "years": years},
        # Longer histories take the server longer to extract.
        timeout=120 if years <= 3 else 180,
        parse_error_label="Response: ",
    )
    if error:
        return error
    if not isinstance(metrics, dict):
        return f"❌ Unexpected response from extract_financial_metrics: {metrics}"
    rows = metrics.get('data', [])  # the MCP server returns the periods under 'data'
    if not isinstance(rows, list):
        rows = []
    rows = sorted(_dedupe_by_period(rows), key=_period_sort_key, reverse=True)

    lines = [
        f"## {years}-Year Financial Trends ({metrics.get('periods', 0)} periods)\n\n",
        "| Period | Revenue (B) | Net Income (B) | EPS | Operating Expenses (B) | Operating Cash Flow (B) | Source Form |\n",
        "|--------|-------------|----------------|-----|------------------------|-------------------------|-------------|\n",
    ]
    for row in rows:
        period = str(row.get('period') or 'N/A')
        cells = [
            _format_period_label(period),
            format_value(_to_billions(row.get('total_revenue'))),
            format_value(_to_billions(row.get('net_income'))),
            format_value(row.get('earnings_per_share') or 0, 'eps'),
            format_value(_to_billions(row.get('operating_expenses'))),
            format_value(_to_billions(row.get('operating_cash_flow'))),
            create_source_link(row.get('source_form', 'N/A'), row.get('source_url')),
        ]
        lines.append("| " + " | ".join(str(cell) for cell in cells) + " |\n")
    return "".join(lines)


def _render_company_filings(cik):
    """Fetch up to 50 filings for ``cik`` and render them as a Markdown table."""
    data, error = _fetch_mcp_tool_data("get_company_filings", {"cik": cik, "limit": 50}, timeout=90)
    if error:
        return error
    filings = data.get('filings', []) if isinstance(data, dict) else data
    if not isinstance(filings, list):
        return f"❌ Unexpected response from get_company_filings: {data}"

    lines = [
        f"## Company Filings ({len(filings)} records)\n\n",
        "| Form Type | Filing Date | Accession Number | Primary Document |\n",
        "|-----------|-------------|------------------|------------------|\n",
    ]
    for filing in filings:
        if not isinstance(filing, dict):
            continue
        form_type = filing.get('form_type', 'N/A')
        primary_doc = filing.get('primary_document', 'N/A')
        filing_url = filing.get('filing_url')  # URL supplied by the backend
        if filing_url and filing_url != 'N/A':
            form_type = f"[{form_type}]({filing_url})"
            primary_doc = f"[{primary_doc}]({filing_url})"
        lines.append(
            f"| {form_type} | {filing.get('filing_date', 'N/A')} | "
            f"{filing.get('accession_number', 'N/A')} | {primary_doc} |\n"
        )
    return "".join(lines)


def query_financial_data(company_name, query_type):
    """Look up a company via the MCP server and render a Markdown report.

    Args:
        company_name: company name or ticker typed by the user.
        query_type: "Latest Financial Data", "3-Year Trends", "5-Year Trends"
            or "Company Filings" (the Chinese internal names are accepted too).

    Returns:
        A Markdown string. Failures are reported as text containing "❌"
        instead of raising: unsupported query types are rejected before any
        network call, and network/unexpected exceptions are caught here.
    """
    if not company_name:
        return "Please enter a company name or stock symbol"

    # English UI labels -> internal query names used for dispatch below.
    query_type_mapping = {
        "Latest Financial Data": "最新财务数据",
        "3-Year Trends": "3年趋势",
        "5-Year Trends": "5年趋势",
        "Company Filings": "公司报表列表"
    }
    internal_query_type = query_type_mapping.get(query_type, query_type)
    if internal_query_type not in query_type_mapping.values():
        return f"❌ Unsupported query type: {query_type}"

    try:
        # Step 1: resolve the company (CIK, name, ticker) via advanced_search_company.
        try:
            search_result = call_mcp_tool("advanced_search_company", {"company_input": company_name})
        except requests.exceptions.Timeout:
            return f"❌ MCP Server Timeout: The server took too long to respond (>60s).\n\n**Possible reasons**:\n1. MCP Server is cold starting (first request after idle)\n2. Server is overloaded\n3. Network issues\n\n**Suggestion**: Please try again in a few moments. If the problem persists, the MCP Server at {MCP_URL} may be down."
        if isinstance(search_result, dict) and "error" in search_result:
            return f"❌ Server Error: {search_result.get('error')}\n\nResponse: {search_result.get('detail', 'N/A')}\n\nURL: {search_result.get('url', MCP_URL)}"

        company = parse_mcp_response(search_result)
        if isinstance(company, dict) and company.get("error"):
            return f"❌ Error: {company['error']}"
        if not isinstance(company, dict):
            return f"❌ Error: Unexpected company search response: {company}"

        # advanced_search returns cik, name and ticker (no sic_description).
        result = (
            f"# {company.get('name', 'Unknown')}\n\n"
            f"**Stock Symbol**: {company.get('ticker', 'N/A')}\n"
            "\n---\n\n"
        )

        cik = normalize_cik(company.get('cik'))
        if not cik:
            return result + f"❌ Error: Invalid CIK from company search\n\nDebug: company data = {json.dumps(company, indent=2)}"

        # Step 2: fetch and render the requested section.
        if internal_query_type == "最新财务数据":
            return result + _render_latest_financials(cik)
        if internal_query_type == "3年趋势":
            return result + _render_trend_table(cik, 3)
        if internal_query_type == "5年趋势":
            return result + _render_trend_table(cik, 5)
        return result + _render_company_filings(cik)

    except requests.exceptions.RequestException as e:
        return f"❌ Network Error: {str(e)}\n\nMCP Server: {MCP_URL}"
    except Exception as e:
        import traceback
        return f"❌ Unexpected Error: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
# MCP tool-call helper used by the chatbot loop and the company search.
def call_mcp_tool(tool_name, arguments, timeout=60):
    """Invoke an MCP tool over JSON-RPC and return its decoded payload.

    Sends a JSON-RPC 2.0 ``tools/call`` request to ``MCP_URL + MCP_ENDPOINT``
    using the shared retrying session, then decodes the reply with
    ``parse_mcp_response``.

    Args:
        tool_name: name of the MCP tool to invoke.
        arguments: JSON-serialisable dict of tool arguments.
        timeout: request timeout in seconds (default 60).

    Returns:
        The decoded tool payload (usually a dict). On failure, a dict with an
        ``"error"`` message and the request ``"url"``; HTTP errors also carry a
        ``"detail"`` body excerpt. These are the keys query_financial_data
        reads when reporting server errors. Never raises.
    """
    url = f"{MCP_URL}{MCP_ENDPOINT}"
    try:
        response = session.post(
            url,
            json={
                "jsonrpc": "2.0",
                "method": "tools/call",
                "params": {
                    "name": tool_name,
                    "arguments": arguments
                },
                "id": 1
            },
            headers=HEADERS,
            timeout=timeout
        )
        if response.status_code != 200:
            return {
                "error": f"HTTP {response.status_code}: {response.text[:200]}",
                "detail": response.text[:500],
                "url": url,
            }
        return parse_mcp_response(response.json())
    except Exception as e:
        # Tool boundary: report failures as data so the chatbot/UI can show them.
        return {"error": str(e), "url": url}
# Chatbot: LLM + MCP tools.
def _history_to_messages(history, max_messages=10):
    """Convert Gradio chat history into chat-completion messages.

    Accepts both history shapes this app may receive: message dicts
    (``{"role": ..., "content": ...}``) and legacy ``(user, assistant)``
    pairs. Only ``role``/``content`` are forwarded (Gradio message dicts may
    carry extra UI keys), entries with empty or non-text content are skipped,
    and only the last ``max_messages`` messages are kept (10 = 5 rounds).
    """
    messages = []
    for item in history or []:
        if isinstance(item, dict):
            pairs = [(item.get("role"), item.get("content"))]
        elif isinstance(item, (list, tuple)) and len(item) == 2:
            pairs = [("user", item[0]), ("assistant", item[1])]
        else:
            continue
        for role, content in pairs:
            if role in ("user", "assistant") and isinstance(content, str) and content:
                messages.append({"role": role, "content": content})
    return messages[-max_messages:]


def _parse_tool_arguments(raw_arguments):
    """Return tool-call arguments as a dict.

    Depending on the serving backend, ``function.arguments`` may arrive as a
    JSON string or as an already-decoded dict, so both are accepted. Raises
    ``json.JSONDecodeError`` on malformed JSON strings.
    """
    if isinstance(raw_arguments, dict):
        return raw_arguments
    if not raw_arguments:
        return {}
    parsed = json.loads(raw_arguments)
    return parsed if isinstance(parsed, dict) else {}


def chatbot_response(message, history):
    """Answer a chat message with the LLM, letting it call MCP tools.

    Runs up to 5 completion rounds; each round either executes the requested
    MCP tools (feeding their results back to the model) or ends with the
    model's text answer. Falls back to ``fallback_chatbot_response`` when the
    Inference client is unavailable or a round fails. The reply is prefixed
    with a log of the MCP tools that were used.
    """
    try:
        if client is None:
            return fallback_chatbot_response(message)

        system_prompt = """You are a helpful financial data assistant with access to SEC EDGAR data through specialized tools.
You can help users with:
- General questions and conversations about any topic
- Financial data queries for US listed companies
- Company information and stock data analysis
When users ask about financial data, company information, or stock performance, you should use the available tools to retrieve accurate, real-time data from SEC EDGAR filings.
Available tools:
1. advanced_search_company: Search for company information by name or ticker
2. get_latest_financial_data: Get the latest financial metrics for a company
3. extract_financial_metrics: Get historical financial trends (3 or 5 years)
Always be helpful, accurate, and cite the data sources when providing financial information."""
        messages = [{"role": "system", "content": system_prompt}]
        messages.extend(_history_to_messages(history))
        messages.append({"role": "user", "content": message})

        response_text = ""
        tool_calls_log = []
        max_iterations = 5  # guard against endless tool-call loops
        for _ in range(max_iterations):
            try:
                response = client.chat_completion(
                    messages=messages,
                    model="Qwen/Qwen2.5-72B-Instruct",  # model with tool-calling support
                    tools=MCP_TOOLS,
                    max_tokens=2000,
                    temperature=0.7
                )
                choice = response.choices[0]
                if not choice.message.tool_calls:
                    # No tool requested: this is the final answer.
                    response_text = choice.message.content or ""
                    break

                messages.append(choice.message)
                for tool_call in choice.message.tool_calls:
                    tool_name = tool_call.function.name
                    tool_args = _parse_tool_arguments(tool_call.function.arguments)
                    tool_calls_log.append({"name": tool_name, "arguments": tool_args})
                    tool_result = call_mcp_tool(tool_name, tool_args)
                    # Feed the result back so the next round can use it.
                    messages.append({
                        "role": "tool",
                        "name": tool_name,
                        "content": json.dumps(tool_result),
                        "tool_call_id": tool_call.id
                    })
            except Exception as e:
                # LLM/tool round failed: degrade to the keyword-based answer.
                print(f"Warning: chatbot round failed, using fallback: {e}")
                return fallback_chatbot_response(message)
        else:
            # Every round requested more tools; don't return an empty answer.
            response_text = "⚠️ I couldn't finish answering within the tool-call limit. Please try rephrasing your question."

        final_response = ""
        if tool_calls_log:
            final_response += "**🛠️ MCP Tools Used:**\n\n"
            for i, tool_call in enumerate(tool_calls_log, 1):
                final_response += f"{i}. `{tool_call['name']}` with arguments: `{json.dumps(tool_call['arguments'])}`\n"
            final_response += "\n---\n\n"
        final_response += response_text
        return final_response

    except Exception as e:
        import traceback
        return f"❌ Error: {str(e)}\n\nTraceback:\n```\n{traceback.format_exc()}\n```"
def fallback_chatbot_response(message):
    """Answer without the LLM, using simple keyword matching.

    - Messages without financial keywords get a generic capabilities reply.
    - Financial messages naming a known company (by name or ticker, first
      match in the alias table wins) run ``query_financial_data``. Trend-style
      wording selects "3-Year Trends", or "5-Year Trends" when the message
      mentions a standalone 5, "five" or 五年. Other wording selects
      "Latest Financial Data".
    - Financial messages without a recognised company get a usage hint.
    Matching is case-insensitive.
    """
    import re

    text = message.lower()
    financial_keywords = ('financial', 'revenue', 'income', 'earnings', 'cash flow', 'expenses', '财务', '收入', '利润', 'data', 'trend', 'performance')
    if not any(keyword in text for keyword in financial_keywords):
        return "Hello! I'm a financial data assistant powered by SEC EDGAR data. I can help you query financial information for US listed companies.\n\n**What I can do:**\n- Get latest financial data (revenue, income, EPS, etc.)\n- Show 3-year or 5-year financial trends\n- Provide detailed financial metrics\n\n**Try asking:**\n- 'Show me Apple's latest financial data'\n- 'What's NVIDIA's 3-year financial trend?'\n- 'How is Microsoft performing financially?'"

    # keyword -> company name passed to the search; checked in insertion order.
    company_aliases = {
        'apple': 'Apple', 'microsoft': 'Microsoft', 'nvidia': 'Nvidia', 'tesla': 'Tesla',
        'alibaba': 'Alibaba', 'google': 'Google', 'amazon': 'Amazon', 'meta': 'Meta',
        'tsla': 'Tesla', 'aapl': 'Apple', 'msft': 'Microsoft', 'nvda': 'NVIDIA',
        'googl': 'Google', 'amzn': 'Amazon',
    }
    detected_company = next((name for keyword, name in company_aliases.items() if keyword in text), None)
    if detected_company is None:
        return "I can help you query financial data! Please specify a company name. For example: 'Show me Apple's latest financial data' or 'What's NVIDIA's 3-year trend?' \n\nSupported companies include: Apple, Microsoft, NVIDIA, Tesla, Alibaba, Google, Amazon, and more."

    if any(word in text for word in ('trend', '趋势', 'history', 'historical', 'over time')):
        # A standalone "5" (not part of a number such as "2025"), "five" or 五年.
        wants_five_years = bool(re.search(r"(?<!\d)5(?!\d)", text)) or 'five' in text or '五年' in text
        query_type = '5-Year Trends' if wants_five_years else '3-Year Trends'
    else:
        query_type = 'Latest Financial Data'

    result = query_financial_data(detected_company, query_type)
    return f"Here's the financial information for {detected_company}:\n\n{result}"
# Direct Query wrapper that streams status banners to the UI.
def query_with_status(company, query_type):
    """Run ``query_financial_data`` while streaming status banners.

    Yields ``(status_html, result_markdown)`` pairs:
      1. a loading banner with empty results;
      2. a completion banner with the results. query_financial_data reports
         failures as text containing "❌" instead of raising, so such results
         get a warning banner rather than "completed successfully".
    If anything still raises, a single error banner with empty results is
    yielded. User-supplied text and exception messages are HTML-escaped
    before being placed in the banner markup.
    """
    import html

    try:
        safe_company = html.escape(str(company))
        yield "<div style='padding: 10px; background: #e3f2fd; border-left: 4px solid #2196f3; margin: 10px 0;'>🔄 <strong>Loading...</strong> Querying SEC EDGAR data for <strong>{}</strong>...</div>".format(safe_company), ""
        result = query_financial_data(company, query_type)
        if isinstance(result, str) and "❌" in result:
            yield "<div style='padding: 10px; background: #fff8e1; border-left: 4px solid #ff9800; margin: 10px 0;'>⚠️ <strong>Query finished with errors.</strong> See the details below.</div>", result
        else:
            yield "<div style='padding: 10px; background: #e8f5e9; border-left: 4px solid #4caf50; margin: 10px 0;'>✅ <strong>Query completed successfully!</strong></div>", result
    except Exception as e:
        yield "<div style='padding: 10px; background: #ffebee; border-left: 4px solid #f44336; margin: 10px 0;'>❌ <strong>Error:</strong> {}</div>".format(html.escape(str(e))), ""
# Build the Gradio UI: an LLM-backed chat tab and a form-based "Direct Query" tab.
with gr.Blocks(title="SEC Financial Data Query Assistant") as demo:
    gr.Markdown("# 🤖 SEC Financial Data Query Assistant")
    with gr.Tab("AI Assistant"):
        # Gradio ChatInterface (compatible with Gradio 4.44.1); chatbot_response
        # accepts both tuple-style and dict-style chat history.
        chat = gr.ChatInterface(
            fn=chatbot_response,
            examples=[
                "Show me Apple's latest financial data",
                "What's NVIDIA's 3-year financial trend?",
                "Get Microsoft's 5-year financial trends",
                "How is Tesla performing financially?"
            ],
            cache_examples=False
        )
    with gr.Tab("Direct Query"):
        gr.Markdown("## 🔍 Direct Financial Data Query")
        gr.Markdown("Select a company and query type to retrieve financial information.")
        with gr.Row():
            company_input = gr.Textbox(
                label="Company Name or Stock Symbol",
                placeholder="e.g., NVIDIA, Apple, Alibaba, AAPL",
                scale=2
            )
            # These labels must match the keys query_financial_data maps to its
            # internal query types.
            query_type = gr.Radio(
                ["Latest Financial Data", "3-Year Trends", "5-Year Trends", "Company Filings"],
                label="Query Type",
                value="Latest Financial Data",
                scale=1
            )
        submit_btn = gr.Button("🔍 Query Financial Data", variant="primary", size="lg")
        # Loading/status indicator, filled in by query_with_status.
        with gr.Row():
            status_text = gr.Markdown("")
        output = gr.Markdown(label="Query Results")
        # Example inputs; with cache_examples=False they are not pre-computed at startup.
        gr.Examples(
            examples=[
                ["NVIDIA", "Latest Financial Data"],
                ["Apple", "3-Year Trends"],
                ["Microsoft", "5-Year Trends"],
                ["Alibaba", "Company Filings"],
                ["Tesla", "3-Year Trends"]
            ],
            inputs=[company_input, query_type],
            outputs=output,
            fn=query_financial_data,
            cache_examples=False
        )
        # query_with_status is a generator: it yields a loading banner first,
        # then the final banner together with the Markdown results.
        submit_btn.click(
            fn=query_with_status,
            inputs=[company_input, query_type],
            outputs=[status_text, output],
            show_progress="full"  # show the full progress bar
        )
    gr.Markdown("---")
    gr.Markdown("**Data Source**: SEC EDGAR | **MCP Server**: https://huggingface.co/spaces/JC321/EasyReportDateMCP")
# Script entry point: launch the app for the Hugging Face Space.
if __name__ == "__main__":
    # Listen on all interfaces on port 7860 and show errors in the UI.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_error": True,
    }
    demo.launch(**launch_options)