# Financial AI Assistant — Hugging Face Space (Gradio chat app backed by MCP tool services)
import json
import os
import warnings

import gradio as gr
import requests
from huggingface_hub import InferenceClient
# Silence asyncio-related DeprecationWarnings emitted at startup.
warnings.filterwarnings('ignore', category=DeprecationWarning)
os.environ['PYTHONWARNINGS'] = 'ignore'

# Hide GPUs from CUDA unless the deployment explicitly configured them;
# this app only calls remote inference endpoints and needs no local GPU.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '')
# ========== Simplified MCP tool definitions (MCP-protocol compliant) ==========
def _mcp_tool(name, description, properties, required):
    """Build one OpenAI-style function-tool schema for the LLM tools API."""
    return {
        "type": "function",
        "function": {
            "name": name,
            "description": description,
            "parameters": {
                "type": "object",
                "properties": properties,
                "required": required,
            },
        },
    }


MCP_TOOLS = [
    _mcp_tool("advanced_search_company", "Search US companies",
              {"company_input": {"type": "string"}}, ["company_input"]),
    _mcp_tool("get_latest_financial_data", "Get latest financial data",
              {"cik": {"type": "string"}}, ["cik"]),
    _mcp_tool("extract_financial_metrics", "Get multi-year trends",
              {"cik": {"type": "string"}, "years": {"type": "integer"}}, ["cik", "years"]),
    _mcp_tool("get_quote", "Get stock quote",
              {"symbol": {"type": "string"}}, ["symbol"]),
    _mcp_tool("get_market_news", "Get market news",
              {"category": {"type": "string"}}, ["category"]),
    _mcp_tool("get_company_news", "Get company news",
              {"symbol": {"type": "string"}, "from_date": {"type": "string"}, "to_date": {"type": "string"}},
              ["symbol"]),
]
# ========== MCP service configuration ==========
MCP_SERVICES = {
    "financial": {"url": "https://jc321-easyreportdatemcp.hf.space/mcp", "type": "fastmcp"},
    "market": {"url": "https://jc321-marketandstockmcp.hf.space", "type": "gradio"}
}

# Which backend hosts each tool; TOOL_ROUTING values alias the
# MCP_SERVICES entries (shared dict objects, not copies).
_TOOL_SERVICE_KEYS = {
    "advanced_search_company": "financial",
    "get_latest_financial_data": "financial",
    "extract_financial_metrics": "financial",
    "get_quote": "market",
    "get_market_news": "market",
    "get_company_news": "market",
}
TOOL_ROUTING = {tool: MCP_SERVICES[key] for tool, key in _TOOL_SERVICE_KEYS.items()}
# ========== Initialise the LLM client ==========
# Prefer HF_TOKEN, then the legacy HUGGING_FACE_HUB_TOKEN variable; fall back
# to an anonymous (rate-limited) client when neither is set.
hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
client = InferenceClient(api_key=hf_token) if hf_token else InferenceClient()
# Fix: these were f-strings without placeholders (F541) — plain literals suffice.
print("✅ LLM initialized: Qwen/Qwen2.5-72B-Instruct:novita")
print(f"📊 MCP Services: {len(MCP_SERVICES)} services, {len(MCP_TOOLS)} tools")
# ========== System prompt (simplified) ==========
from datetime import datetime


def get_system_prompt():
    """Return the system prompt stamped with today's date.

    Regenerated per request so long-running sessions always give the model
    the correct current date for time-sensitive queries (news, quotes).
    """
    today = datetime.now().strftime("%Y-%m-%d")
    role = (
        "You are a financial analysis assistant. Use tools to get data on "
        "company financials (past 5-year reports), current stock prices, "
        "market news, and company news. Provide data-driven insights."
    )
    date_note = (
        f"IMPORTANT: Today's date is {today}. When querying news or "
        "time-sensitive data, use recent dates relative to today."
    )
    return f"{role}\n{date_note}"
# ============================================================
# MCP service invocation core
# Supports both FastMCP (JSON-RPC) and Gradio (SSE) protocols
# ============================================================
def call_mcp_tool(tool_name, arguments):
    """Dispatch one tool call to whichever MCP backend hosts it.

    Returns the decoded tool result, or an {"error": ...} dict for an
    unknown tool, unknown protocol, or any transport/parse failure.
    """
    service = TOOL_ROUTING.get(tool_name)
    if service is None:
        return {"error": f"Unknown tool: {tool_name}"}
    # Map each service protocol to its transport implementation.
    protocol_handlers = {"fastmcp": _call_fastmcp, "gradio": _call_gradio_api}
    try:
        handler = protocol_handlers.get(service["type"])
        if handler is None:
            return {"error": "Unknown service type"}
        return handler(service["url"], tool_name, arguments)
    except Exception as exc:  # surface failures as error payloads, never raise
        return {"error": str(exc)}
def _call_fastmcp(service_url, tool_name, arguments):
    """Call a tool over the standard MCP JSON-RPC protocol.

    Unwraps the MCP envelope (jsonrpc -> result -> content[0].text) and
    JSON-decodes the inner text when possible; falls back to progressively
    less-unwrapped payloads otherwise.
    """
    payload = {
        "jsonrpc": "2.0",
        "method": "tools/call",
        "params": {"name": tool_name, "arguments": arguments},
        "id": 1,
    }
    resp = requests.post(
        service_url,
        json=payload,
        headers={"Content-Type": "application/json"},
        timeout=30,
    )
    if resp.status_code != 200:
        return {"error": f"HTTP {resp.status_code}"}

    data = resp.json()
    if not (isinstance(data, dict) and "result" in data):
        # Not a JSON-RPC success envelope — return the raw body.
        return data

    result = data["result"]
    content = result.get("content") if isinstance(result, dict) else None
    if isinstance(content, list) and content:
        first = content[0]
        if isinstance(first, dict) and "text" in first:
            try:
                return json.loads(first["text"])
            except (json.JSONDecodeError, TypeError):
                # Text payload is not JSON — hand it back verbatim.
                return {"text": first["text"]}
    return result
def _call_gradio_api(service_url, tool_name, arguments):
    """Call a tool exposed by a Gradio app via its two-step SSE API.

    Step 1: POST /call/<fn> with positional args -> event_id.
    Step 2: GET  /call/<fn>/<event_id> and parse the SSE stream for the
    first "data: [...]" payload, returned as {"text": <first element>}.
    """
    fn_names = {
        "get_quote": "test_quote_tool",
        "get_market_news": "test_market_news_tool",
        "get_company_news": "test_company_news_tool",
    }
    fn_name = fn_names.get(tool_name)
    if fn_name is None:
        return {"error": "No mapping"}

    # Positional-argument order each Gradio endpoint expects.
    arg_builders = {
        "get_quote": lambda a: [a.get("symbol", "")],
        "get_market_news": lambda a: [a.get("category", "general")],
        "get_company_news": lambda a: [a.get("symbol", ""), a.get("from_date", ""), a.get("to_date", "")],
    }
    params = arg_builders.get(tool_name, lambda a: [])(arguments)

    # Submit the job.
    endpoint = f"{service_url}/call/{fn_name}"
    submit = requests.post(endpoint, json={"data": params}, timeout=10)
    if submit.status_code != 200:
        return {"error": f"HTTP {submit.status_code}"}
    event_id = submit.json().get("event_id")
    if not event_id:
        return {"error": "No event_id"}

    # Fetch and parse the SSE result stream.
    stream = requests.get(f"{endpoint}/{event_id}", stream=True, timeout=20)
    if stream.status_code != 200:
        return {"error": f"HTTP {stream.status_code}"}
    for raw in stream.iter_lines():
        if not raw:
            continue
        decoded = raw.decode('utf-8')
        if not decoded.startswith('data: '):
            continue
        try:
            payload = json.loads(decoded[6:])
        except json.JSONDecodeError:
            continue
        if isinstance(payload, list) and payload:
            return {"text": payload[0]}
    return {"error": "No result"}
# ============================================================
# End of MCP service invocation section
# ============================================================
def chatbot_response(message, history):
    """Main assistant entry point (streaming generator for gr.ChatInterface).

    Runs an agent loop: the LLM may request MCP tool calls for up to 5
    rounds; each tool result is appended back into the conversation, and
    the final answer is streamed token-by-token to the UI.

    Args:
        message: the user's new message (str).
        history: prior chat turns; assumes Gradio's tuple format of
            (user, assistant) pairs — TODO confirm against the installed
            Gradio version's ChatInterface message format.

    Yields:
        Progressively longer markdown strings: a tool-usage summary prefix
        followed by the accumulating model answer (or an error message).
    """
    try:
        messages = [{"role": "system", "content": get_system_prompt()}]
        # Replay recent history (last 5 turns) into the prompt context.
        if history:
            for item in history[-5:]:
                if isinstance(item, (list, tuple)) and len(item) == 2:
                    messages.append({"role": "user", "content": item[0]})
                    messages.append({"role": "assistant", "content": item[1]})
        messages.append({"role": "user", "content": message})
        tool_calls_log = []
        # LLM call loop (at most 5 rounds of tool calls).
        for iteration in range(5):
            response = client.chat_completion(
                messages=messages,
                model="Qwen/Qwen2.5-72B-Instruct:novita",
                tools=MCP_TOOLS,
                max_tokens=2000,
                temperature=0.5,
                tool_choice="auto",
                stream=False
            )
            choice = response.choices[0]
            if choice.message.tool_calls:
                # Echo the assistant's tool-call message back into context.
                messages.append(choice.message)
                for tool_call in choice.message.tool_calls:
                    tool_name = tool_call.function.name
                    tool_args = json.loads(tool_call.function.arguments)
                    # Invoke the MCP tool.
                    tool_result = call_mcp_tool(tool_name, tool_args)
                    # Cap the result size fed to the LLM; oversized payloads
                    # can trigger 500 errors from the inference endpoint.
                    result_str = json.dumps(tool_result, ensure_ascii=False)
                    if len(result_str) > 4000:
                        # Truncate overly long results
                        tool_result_truncated = {"_truncated": True, "preview": result_str[:4000] + "..."}
                        result_for_llm = json.dumps(tool_result_truncated)
                    else:
                        result_for_llm = result_str
                    # Log the call (with the full, untruncated result) for UI display.
                    tool_calls_log.append({"name": tool_name, "arguments": tool_args, "result": tool_result})
                    messages.append({
                        "role": "tool",
                        "name": tool_name,
                        "content": result_for_llm,
                        "tool_call_id": tool_call.id
                    })
                continue
            else:
                break
        # Build the response prefix (simplified rendering).
        response_prefix = ""
        # Render tool calls as plain text with collapsible output sections.
        if tool_calls_log:
            response_prefix += "**🛠️ Tools Used:**\n\n"
            for idx, tool_call in enumerate(tool_calls_log):
                # Tool name and input arguments.
                response_prefix += f"**{idx+1}. `{tool_call['name']}`**\n"
                response_prefix += f"- 📥 Input: `{json.dumps(tool_call['arguments'], ensure_ascii=False)}`\n"
                # Output (collapsible <details> block).
                if 'result' in tool_call:
                    result_str = json.dumps(tool_call['result'], ensure_ascii=False, indent=2)
                    # Truncate long results for display.
                    if len(result_str) > 500:
                        result_preview = result_str[:500] + "..."
                    else:
                        result_preview = result_str
                    response_prefix += f"<details>\n<summary>📤 Output (click to expand)</summary>\n\n```json\n{result_preview}\n```\n</details>\n\n"
                else:
                    response_prefix += "\n"
            response_prefix += "---\n\n"
        # Stream the final answer, always prefixed by the tool summary.
        yield response_prefix
        stream = client.chat_completion(
            messages=messages,
            model="Qwen/Qwen2.5-72B-Instruct:novita",
            tools=MCP_TOOLS,
            max_tokens=2000,
            temperature=0.5,
            stream=True
        )
        accumulated_text = ""
        for chunk in stream:
            if chunk.choices and len(chunk.choices) > 0 and chunk.choices[0].delta.content:
                accumulated_text += chunk.choices[0].delta.content
                yield response_prefix + accumulated_text
    except Exception as e:
        # UI boundary: convert any failure into a visible chat message.
        import traceback
        error_detail = str(e)
        if "500" in error_detail:
            yield f"❌ Error: 模型服务器错误。可能是数据太大或请求超时。\n\n详细信息: {error_detail[:200]}"
        else:
            yield f"❌ Error: {error_detail}\n\n{traceback.format_exc()[:500]}"
# ========== Gradio UI (minimal) ==========
with gr.Blocks(title="Financial AI Assistant") as demo:
    gr.Markdown("# 💬 Financial AI Assistant")
    # ChatInterface consumes chatbot_response's generator output for
    # streaming; `examples` seeds one-click demo queries.
    chat = gr.ChatInterface(
        fn=chatbot_response,
        examples=[
            "What's Apple's latest revenue and profit?",
            "Show me NVIDIA's 3-year financial trends",
            "How is Tesla's stock performing today?",
            "Get the latest market news about crypto",
            "Compare Microsoft's latest earnings with its current stock price",
        ],
        chatbot=gr.Chatbot(height=600),
    )
# Application entry point
if __name__ == "__main__":
    import sys
    import asyncio  # fix: was imported twice (here and again inside the guard)

    # Work around asyncio event-loop issues on Linux hosts by forcing the
    # default loop policy before Gradio starts its server.
    if sys.platform == 'linux':
        try:
            asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy())
        except Exception:
            # fix: was a bare `except:` (would also swallow SystemExit /
            # KeyboardInterrupt); best-effort — keep the current policy.
            pass
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
        ssr_mode=False,
        quiet=False
    )