Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import requests | |
| import json | |
| import os | |
| import warnings | |
| from huggingface_hub import InferenceClient | |
# Silence asyncio-related DeprecationWarnings emitted by dependencies.
warnings.filterwarnings('ignore', category=DeprecationWarning)
os.environ['PYTHONWARNINGS'] = 'ignore'

# Running on a GPU host but not using the GPU: hide CUDA devices unless
# the caller already configured them explicitly.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '')
# ========== MCP tool definitions (MCP-protocol compatible, simplified) ==========
def _mcp_tool(name, description, properties, required):
    """Build one OpenAI-style function-tool schema entry."""
    return {
        "type": "function",
        "function": {
            "name": name,
            "description": description,
            "parameters": {
                "type": "object",
                "properties": properties,
                "required": required,
            },
        },
    }

# Tool schemas advertised to the LLM for function calling.
MCP_TOOLS = [
    _mcp_tool("advanced_search_company", "Search US companies",
              {"company_input": {"type": "string"}}, ["company_input"]),
    _mcp_tool("get_latest_financial_data", "Get latest financial data",
              {"cik": {"type": "string"}}, ["cik"]),
    _mcp_tool("extract_financial_metrics", "Get multi-year trends",
              {"cik": {"type": "string"}, "years": {"type": "integer"}}, ["cik", "years"]),
    _mcp_tool("get_quote", "Get stock quote",
              {"symbol": {"type": "string"}}, ["symbol"]),
    _mcp_tool("get_market_news", "Get market news",
              {"category": {"type": "string"}}, ["category"]),
    _mcp_tool("get_company_news", "Get company news",
              {"symbol": {"type": "string"}, "from_date": {"type": "string"}, "to_date": {"type": "string"}}, ["symbol"]),
]
# ========== MCP service configuration ==========
MCP_SERVICES = {
    "financial": {"url": "https://jc321-easyreportdatamcp.hf.space/mcp", "type": "fastmcp"},
    "market": {"url": "https://jc321-marketandstockmcp.hf.space", "type": "gradio"},
}

# Map every tool name onto the (shared) config dict of the service hosting it.
TOOL_ROUTING = {
    tool: MCP_SERVICES[service]
    for service, tools in (
        ("financial", ("advanced_search_company", "get_latest_financial_data", "extract_financial_metrics")),
        ("market", ("get_quote", "get_market_news", "get_company_news")),
    )
    for tool in tools
}
# ========== LLM client initialization ==========
# Prefer HF_TOKEN, fall back to HUGGING_FACE_HUB_TOKEN; anonymous otherwise.
hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
if hf_token:
    client = InferenceClient(api_key=hf_token)
else:
    client = InferenceClient()
print("✅ LLM initialized: Qwen/Qwen3-32B:groq")
print(f"📊 MCP Services: {len(MCP_SERVICES)} services, {len(MCP_TOOLS)} tools")
# ========== Token budget configuration ==========
# The HuggingFace Inference API's practical context limit is roughly
# 8000-16000 tokens; we stay well below it for safety.
MAX_TOTAL_TOKENS = 6000  # total context budget (tokens)
MAX_TOOL_RESULT_CHARS = 1500  # max characters of a tool result fed back to the LLM (raised to 1500)
MAX_HISTORY_CHARS = 500  # max characters kept per historical assistant message
MAX_HISTORY_TURNS = 2  # number of past (user, assistant) turns replayed
MAX_TOOL_ITERATIONS = 6  # max LLM<->tool round trips (raised to 6 for multi-tool flows)
MAX_OUTPUT_TOKENS = 2000  # max tokens for the model's answer (raised to 2000)
def estimate_tokens(text):
    """Roughly estimate the token count of *text*.

    Uses a crude heuristic of one token per two characters; non-string
    inputs are stringified first.
    """
    as_text = str(text)
    return len(as_text) // 2
def truncate_text(text, max_chars, suffix="...[truncated]"):
    """Return *text* capped at *max_chars* characters.

    Non-string inputs are stringified. When the text is cut, *suffix* is
    appended to signal truncation.
    """
    as_text = str(text)
    if len(as_text) > max_chars:
        return as_text[:max_chars] + suffix
    return as_text
def get_system_prompt():
    """Build the concise, date-stamped system prompt for the analyst persona."""
    from datetime import datetime
    today = datetime.now().strftime("%Y-%m-%d")
    return f"Financial analyst. Today: {today}. Use tools for company data, stock prices, news. Be concise."
| # ============================================================ | |
| # MCP 服务调用核心代码区 | |
| # 支持 FastMCP (JSON-RPC) 和 Gradio (SSE) 两种协议 | |
| # ============================================================ | |
def call_mcp_tool(tool_name, arguments):
    """Dispatch one tool invocation to the MCP service that hosts it.

    Returns the service's parsed result, or an ``{"error": ...}`` dict for
    unknown tools, unknown service types, or any raised exception.
    """
    service = TOOL_ROUTING.get(tool_name)
    if service is None:
        return {"error": f"Unknown tool: {tool_name}"}
    protocol_handlers = {
        "fastmcp": _call_fastmcp,
        "gradio": _call_gradio_api,
    }
    try:
        handler = protocol_handlers.get(service["type"])
        if handler is None:
            return {"error": "Unknown service type"}
        return handler(service["url"], tool_name, arguments)
    except Exception as exc:
        return {"error": str(exc)}
def _call_fastmcp(service_url, tool_name, arguments):
    """Invoke a tool on a FastMCP server via standard MCP JSON-RPC.

    Unwraps the protocol envelope (jsonrpc -> result -> content[0].text)
    and parses the inner text as JSON when possible; falls back to
    returning the less-unwrapped payload at each step.
    """
    payload = {
        "jsonrpc": "2.0",
        "method": "tools/call",
        "params": {"name": tool_name, "arguments": arguments},
        "id": 1,
    }
    response = requests.post(
        service_url,
        json=payload,
        headers={"Content-Type": "application/json"},
        timeout=30,
    )
    if response.status_code != 200:
        return {"error": f"HTTP {response.status_code}"}

    data = response.json()
    if not (isinstance(data, dict) and "result" in data):
        return data

    result = data["result"]
    if not (isinstance(result, dict) and "content" in result):
        return result

    content = result["content"]
    if isinstance(content, list) and content:
        head = content[0]
        if isinstance(head, dict) and "text" in head:
            try:
                return json.loads(head["text"])
            except (json.JSONDecodeError, TypeError):
                # Text wasn't JSON; hand it back verbatim.
                return {"text": head["text"]}
    return result
def _call_gradio_api(service_url, tool_name, arguments):
    """Invoke a tool exposed by a Gradio app via its two-step SSE call API."""
    # Logical tool name -> Gradio endpoint name.
    tool_map = {
        "get_quote": "test_quote_tool",
        "get_market_news": "test_market_news_tool",
        "get_company_news": "test_company_news_tool",
    }
    gradio_fn = tool_map.get(tool_name)
    if not gradio_fn:
        return {"error": "No mapping"}

    # Build the positional argument list each Gradio endpoint expects.
    param_builders = {
        "get_quote": lambda a: [a.get("symbol", "")],
        "get_market_news": lambda a: [a.get("category", "general")],
        "get_company_news": lambda a: [
            a.get("symbol", ""),
            a.get("from_date", ""),
            a.get("to_date", ""),
        ],
    }
    params = param_builders.get(tool_name, lambda a: [])(arguments)

    # Step 1: submit the call and receive an event id.
    call_url = f"{service_url}/call/{gradio_fn}"
    submit_resp = requests.post(call_url, json={"data": params}, timeout=10)
    if submit_resp.status_code != 200:
        return {"error": f"HTTP {submit_resp.status_code}"}
    event_id = submit_resp.json().get("event_id")
    if not event_id:
        return {"error": "No event_id"}

    # Step 2: stream the SSE result for that event.
    result_resp = requests.get(f"{call_url}/{event_id}", stream=True, timeout=20)
    if result_resp.status_code != 200:
        return {"error": f"HTTP {result_resp.status_code}"}

    # Scan SSE lines for the first well-formed "data: [...]" payload.
    for raw_line in result_resp.iter_lines():
        if not raw_line:
            continue
        decoded = raw_line.decode('utf-8')
        if not decoded.startswith('data: '):
            continue
        try:
            event_payload = json.loads(decoded[6:])
        except json.JSONDecodeError:
            continue
        if isinstance(event_payload, list) and event_payload:
            return {"text": event_payload[0]}
    return {"error": "No result"}
| # ============================================================ | |
| # End of MCP 服务调用代码区 | |
| # ============================================================ | |
def chatbot_response(message, history):
    """Main chat handler: run the LLM tool-calling loop and stream the reply.

    Args:
        message: the user's new message (str).
        history: prior turns; items are treated as (user, assistant) pairs
            when they are 2-element lists/tuples (gr.ChatInterface tuple
            format) — other shapes are silently skipped.

    Yields:
        Progressively longer strings (HTML tool-call summary followed by
        the answer text), as Gradio expects from a streaming chat fn.
    """
    try:
        messages = [{"role": "system", "content": get_system_prompt()}]
        # Replay recent history (last MAX_HISTORY_TURNS turns) with strict
        # length caps to keep the context small.
        if history:
            for item in history[-MAX_HISTORY_TURNS:]:
                if isinstance(item, (list, tuple)) and len(item) == 2:
                    # User message (not truncated).
                    messages.append({"role": "user", "content": item[0]})
                    # Assistant reply (hard-truncated to MAX_HISTORY_CHARS).
                    assistant_msg = str(item[1])
                    if len(assistant_msg) > MAX_HISTORY_CHARS:
                        assistant_msg = truncate_text(assistant_msg, MAX_HISTORY_CHARS)
                    messages.append({"role": "assistant", "content": assistant_msg})
        messages.append({"role": "user", "content": message})
        tool_calls_log = []  # every tool invocation, kept for the UI summary
        # LLM call loop (supports several rounds of tool calls).
        final_response_content = None
        for iteration in range(MAX_TOOL_ITERATIONS):
            response = client.chat.completions.create(
                model="Qwen/Qwen3-32B:groq",
                messages=messages,
                tools=MCP_TOOLS,
                max_tokens=MAX_OUTPUT_TOKENS,
                temperature=0.5,
                tool_choice="auto",
                stream=False
            )
            choice = response.choices[0]
            if choice.message.tool_calls:
                # Keep the assistant turn that requested the tools so the
                # follow-up "tool" messages have a matching parent.
                messages.append(choice.message)
                for tool_call in choice.message.tool_calls:
                    tool_name = tool_call.function.name
                    try:
                        tool_args = json.loads(tool_call.function.arguments)
                    except json.JSONDecodeError:
                        # Malformed arguments from the model: call with none.
                        tool_args = {}
                    # Invoke the MCP tool.
                    tool_result = call_mcp_tool(tool_name, tool_args)
                    # Check for a tool-level error.
                    if isinstance(tool_result, dict) and "error" in tool_result:
                        # Tool call failed: log it and surface the error to the LLM.
                        tool_calls_log.append({"name": tool_name, "arguments": tool_args, "result": tool_result, "error": True})
                        result_for_llm = json.dumps({"error": tool_result.get("error", "Unknown error")}, ensure_ascii=False)
                    else:
                        # Cap the size of the result fed back to the LLM.
                        result_str = json.dumps(tool_result, ensure_ascii=False)
                        if len(result_str) > MAX_TOOL_RESULT_CHARS:
                            if isinstance(tool_result, dict) and "text" in tool_result:
                                truncated_text = truncate_text(tool_result["text"], MAX_TOOL_RESULT_CHARS - 50)
                                tool_result_truncated = {"text": truncated_text, "_truncated": True}
                            elif isinstance(tool_result, dict):
                                truncated = {}
                                char_count = 0
                                for k, v in list(tool_result.items())[:8]:  # keep at most the first 8 fields
                                    v_str = str(v)[:300]  # at most 300 chars per value
                                    truncated[k] = v_str
                                    char_count += len(k) + len(v_str)
                                    if char_count > MAX_TOOL_RESULT_CHARS:
                                        break
                                tool_result_truncated = {**truncated, "_truncated": True}
                            else:
                                tool_result_truncated = {"preview": truncate_text(result_str, MAX_TOOL_RESULT_CHARS), "_truncated": True}
                            result_for_llm = json.dumps(tool_result_truncated, ensure_ascii=False)
                        else:
                            result_for_llm = result_str
                        # Log the successful call (full, untruncated result for the UI).
                        tool_calls_log.append({"name": tool_name, "arguments": tool_args, "result": tool_result})
                    messages.append({
                        "role": "tool",
                        "name": tool_name,
                        "content": result_for_llm,
                        "tool_call_id": tool_call.id
                    })
                continue
            else:
                # No further tool calls: keep the final answer and stop looping.
                final_response_content = choice.message.content
                break
        # Build the response prefix (simplified tool-call summary).
        response_prefix = ""
        # Render tool calls with native HTML details/summary tags.
        if tool_calls_log:
            response_prefix += """<div style='margin-bottom: 15px;'>
<div style='background: #f0f0f0; padding: 8px 12px; border-radius: 6px; font-weight: 600; color: #333;'>
🛠️ Tools Used ({} calls)
</div>
""".format(len(tool_calls_log))
            for idx, tool_call in enumerate(tool_calls_log):
                # Pre-compute the JSON strings to avoid repeated serialization.
                args_json = json.dumps(tool_call['arguments'], ensure_ascii=False)
                result_json = json.dumps(tool_call.get('result', {}), ensure_ascii=False, indent=2)
                result_preview = result_json[:1500] + ('...' if len(result_json) > 1500 else '')
                # Error badge for failed calls.
                error_indicator = " ❌ Error" if tool_call.get('error') else ""
                # Native HTML5 details/summary tags (no JavaScript needed).
                response_prefix += f"""<details style='margin: 8px 0; border: 1px solid #ddd; border-radius: 6px; overflow: hidden;'>
<summary style='background: #fff; padding: 10px; cursor: pointer; user-select: none; list-style: none;'>
<div style='display: flex; justify-content: space-between; align-items: center;'>
<div style='flex: 1;'>
<strong style='color: #2c5aa0;'>📌 {idx+1}. {tool_call['name']}{error_indicator}</strong>
<div style='font-size: 0.85em; color: #666; margin-top: 4px;'>📥 Input: <code style='background: #f5f5f5; padding: 2px 6px; border-radius: 3px;'>{args_json}</code></div>
</div>
<span style='font-size: 1.2em; color: #999; margin-left: 10px;'>▶</span>
</div>
</summary>
<div style='background: #f9f9f9; padding: 12px; border-top: 1px solid #eee;'>
<div style='font-size: 0.9em; color: #333;'>
<strong>📤 Output:</strong>
<pre style='background: #fff; padding: 10px; border-radius: 4px; overflow-x: auto; margin-top: 6px; font-size: 0.85em; border: 1px solid #e0e0e0; max-height: 400px; white-space: pre-wrap;'>{result_preview}</pre>
</div>
</div>
</details>
"""
            response_prefix += """</div>
---
"""
        response_prefix += "\n"
        # First flush: the tool summary alone, then stream the final answer.
        yield response_prefix
        if final_response_content:
            # The loop already produced the final answer; emit it directly.
            yield response_prefix + final_response_content
        else:
            # Iteration cap reached without a final answer: ask the model to
            # summarize one last time, with tools disabled.
            try:
                stream = client.chat.completions.create(
                    model="Qwen/Qwen3-32B:groq",
                    messages=messages,
                    tools=None,  # tool use no longer allowed
                    max_tokens=MAX_OUTPUT_TOKENS,
                    temperature=0.5,
                    stream=True
                )
                accumulated_text = ""
                for chunk in stream:
                    if chunk.choices and len(chunk.choices) > 0 and chunk.choices[0].delta.content:
                        accumulated_text += chunk.choices[0].delta.content
                        yield response_prefix + accumulated_text
            except Exception as stream_error:
                # Streaming failed; fall back to a single non-streaming call.
                final_resp = client.chat.completions.create(
                    model="Qwen/Qwen3-32B:groq",
                    messages=messages,
                    tools=None,
                    max_tokens=MAX_OUTPUT_TOKENS,
                    temperature=0.5,
                    stream=False
                )
                yield response_prefix + final_resp.choices[0].message.content
    except Exception as e:
        import traceback
        error_detail = str(e)
        if "500" in error_detail:
            yield f"❌ Error: 模型服务器错误。可能是数据太大或请求超时。\n\n详细信息: {error_detail[:200]}"
        else:
            yield f"❌ Error: {error_detail}\n\n{traceback.format_exc()[:500]}"
# ========== Gradio UI (minimal) ==========
with gr.Blocks(title="Financial AI Assistant") as demo:
    gr.Markdown("# 💬 Financial AI Assistant")
    # Streaming chat interface backed by chatbot_response (a generator fn).
    chat = gr.ChatInterface(
        fn=chatbot_response,
        examples=[
            "What's Apple's latest revenue and profit?",
            "Show me NVIDIA's 3-year financial trends",
            "How is Tesla's stock performing today?",
            "Get the latest market news about crypto",
            "Compare Microsoft's latest earnings with its current stock price",
        ],
        chatbot=gr.Chatbot(height=700),
    )
# Application entry point.
if __name__ == "__main__":
    import sys
    # Work around asyncio event-loop issues on Linux by resetting the
    # default event-loop policy (best effort).
    if sys.platform == 'linux':
        try:
            import asyncio
            asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy())
        except Exception:
            # Fixed: was a bare `except:` which also swallows SystemExit
            # and KeyboardInterrupt; narrow to Exception and keep best-effort.
            pass
    demo.launch(
        server_name="0.0.0.0",  # listen on all interfaces (containers/Spaces)
        server_port=7860,
        show_error=True,
        ssr_mode=False,
        quiet=False
    )