# Hugging Face Space: app.py (commit 83e5dd0, verified) — page metadata
# preserved from the hosting UI: "JC321's picture / Update app.py /
# raw / history blame / 13 kB". Kept as a comment so the file parses.
import gradio as gr
import requests
import json
import os
import warnings
from huggingface_hub import InferenceClient
# Suppress asyncio-related DeprecationWarnings (noisy in the Spaces runtime)
warnings.filterwarnings('ignore', category=DeprecationWarning)
os.environ['PYTHONWARNINGS'] = 'ignore'
# If running in a GPU environment but the GPU is not needed, disable CUDA
# (only when the caller has not already set CUDA_VISIBLE_DEVICES).
if 'CUDA_VISIBLE_DEVICES' not in os.environ:
    os.environ['CUDA_VISIBLE_DEVICES'] = ''
# ========== Compact MCP tool definitions (MCP-protocol compliant) ==========
def _mk_tool(name, description, properties, required):
    """Build one tool spec in the function-calling schema used by the LLM client."""
    return {
        "type": "function",
        "function": {
            "name": name,
            "description": description,
            "parameters": {
                "type": "object",
                "properties": properties,
                "required": required,
            },
        },
    }


MCP_TOOLS = [
    _mk_tool("advanced_search_company", "Search US companies",
             {"company_input": {"type": "string"}}, ["company_input"]),
    _mk_tool("get_latest_financial_data", "Get latest financial data",
             {"cik": {"type": "string"}}, ["cik"]),
    _mk_tool("extract_financial_metrics", "Get multi-year trends",
             {"cik": {"type": "string"}, "years": {"type": "integer"}},
             ["cik", "years"]),
    _mk_tool("get_quote", "Get stock quote",
             {"symbol": {"type": "string"}}, ["symbol"]),
    _mk_tool("get_market_news", "Get market news",
             {"category": {"type": "string"}}, ["category"]),
    _mk_tool("get_company_news", "Get company news",
             {"symbol": {"type": "string"}, "from_date": {"type": "string"},
              "to_date": {"type": "string"}},
             ["symbol"]),
]
# ========== MCP service configuration ==========
MCP_SERVICES = {
    "financial": {"url": "https://jc321-easyreportdatemcp.hf.space/mcp", "type": "fastmcp"},
    "market": {"url": "https://jc321-marketandstockmcp.hf.space", "type": "gradio"},
}

# Map each tool name to the service dict that serves it. Values are the SAME
# objects as the MCP_SERVICES entries (shared by reference, not copied).
TOOL_ROUTING = {}
for _tool_name in ("advanced_search_company", "get_latest_financial_data",
                   "extract_financial_metrics"):
    TOOL_ROUTING[_tool_name] = MCP_SERVICES["financial"]
for _tool_name in ("get_quote", "get_market_news", "get_company_news"):
    TOOL_ROUTING[_tool_name] = MCP_SERVICES["market"]
# ========== Initialize the LLM client ==========
# Prefer HF_TOKEN, fall back to HUGGING_FACE_HUB_TOKEN; if neither is set,
# create an anonymous client (rate-limited but functional).
hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
client = InferenceClient(api_key=hf_token) if hf_token else InferenceClient()
print(f"✅ LLM initialized: Qwen/Qwen2.5-72B-Instruct:novita")
print(f"📊 MCP Services: {len(MCP_SERVICES)} services, {len(MCP_TOOLS)} tools")
# ========== System prompt (simplified) ==========
from datetime import datetime
def get_system_prompt():
    """Return the system prompt with today's date embedded.

    The date is injected so the model grounds "latest"/"recent" queries
    against the real current day rather than its training cutoff.
    """
    today = datetime.now().strftime("%Y-%m-%d")
    return (
        "You are a financial analysis assistant. Use tools to get data on "
        "company financials (past 5-year reports), current stock prices, "
        "market news, and company news. Provide data-driven insights.\n"
        f"IMPORTANT: Today's date is {today}. When querying news or "
        "time-sensitive data, use recent dates relative to today."
    )
# ============================================================
# MCP service invocation core
# Supports both FastMCP (JSON-RPC) and Gradio (SSE) protocols
# ============================================================
def call_mcp_tool(tool_name, arguments):
    """Route a tool invocation to the MCP service that owns it.

    Returns the parsed tool result, or an {"error": ...} dict for an
    unknown tool, an unknown service type, or any raised exception.
    """
    service = TOOL_ROUTING.get(tool_name)
    if service is None:
        return {"error": f"Unknown tool: {tool_name}"}
    try:
        # Dispatch on the service protocol type.
        dispatch = {"fastmcp": _call_fastmcp, "gradio": _call_gradio_api}
        handler = dispatch.get(service["type"])
        if handler is None:
            return {"error": "Unknown service type"}
        return handler(service["url"], tool_name, arguments)
    except Exception as exc:
        # Best-effort boundary: report the failure back as data.
        return {"error": str(exc)}
def _call_fastmcp(service_url, tool_name, arguments):
    """Invoke a tool on a FastMCP server via standard MCP JSON-RPC 2.0.

    Posts a `tools/call` request and unwraps the MCP envelope:
    result -> content[0].text -> JSON payload. Returns an
    {"error": ...} dict on HTTP failure or a JSON-RPC error response.
    """
    response = requests.post(
        service_url,
        json={"jsonrpc": "2.0", "method": "tools/call", "params": {"name": tool_name, "arguments": arguments}, "id": 1},
        headers={"Content-Type": "application/json"},
        timeout=30
    )
    if response.status_code != 200:
        return {"error": f"HTTP {response.status_code}"}
    data = response.json()
    # Fix: a JSON-RPC 2.0 error response carries an "error" member instead of
    # "result"; previously it fell through to `return data` and was never
    # recognized as a failure by callers. Surface it in the file's
    # {"error": ...} convention.
    if isinstance(data, dict) and "error" in data:
        err = data["error"]
        message = err.get("message") if isinstance(err, dict) else None
        return {"error": message or str(err)}
    # Unwrap the MCP protocol envelope: jsonrpc -> result -> content[0].text -> JSON
    if isinstance(data, dict) and "result" in data:
        result = data["result"]
        if isinstance(result, dict) and "content" in result:
            content = result["content"]
            if isinstance(content, list) and len(content) > 0:
                first_item = content[0]
                if isinstance(first_item, dict) and "text" in first_item:
                    try:
                        # Tool payloads are JSON serialized into the text field.
                        return json.loads(first_item["text"])
                    except (json.JSONDecodeError, TypeError):
                        # Plain-text payload: wrap it so callers always get a dict.
                        return {"text": first_item["text"]}
        # Envelope present but no text content: return the result as-is.
        return result
    # Not a recognizable JSON-RPC envelope: pass the raw body through.
    return data
def _call_gradio_api(service_url, tool_name, arguments):
    """Invoke a tool exposed by a Gradio Space via its two-step SSE API.

    Step 1: POST the arguments to /call/<fn> to obtain an event_id.
    Step 2: GET /call/<fn>/<event_id> and parse the `data:` SSE line.
    Returns {"text": ...} on success, {"error": ...} on any failure.
    """
    # Our tool names -> the Space's exported Gradio endpoint names.
    tool_map = {"get_quote": "test_quote_tool", "get_market_news": "test_market_news_tool", "get_company_news": "test_company_news_tool"}
    gradio_fn = tool_map.get(tool_name)
    if not gradio_fn:
        return {"error": "No mapping"}
    # Build the positional parameter list each endpoint expects.
    if tool_name == "get_quote":
        params = [arguments.get("symbol", "")]
    elif tool_name == "get_market_news":
        params = [arguments.get("category", "general")]
    elif tool_name == "get_company_news":
        params = [arguments.get("symbol", ""), arguments.get("from_date", ""), arguments.get("to_date", "")]
    else:
        params = []
    # Submit the request to obtain an event id.
    call_url = f"{service_url}/call/{gradio_fn}"
    resp = requests.post(call_url, json={"data": params}, timeout=10)
    if resp.status_code != 200:
        return {"error": f"HTTP {resp.status_code}"}
    event_id = resp.json().get("event_id")
    if not event_id:
        return {"error": "No event_id"}
    # Fetch the result (SSE). Fix: a stream=True response holds its
    # connection open until closed — use a context manager so it is
    # released on every exit path instead of leaking.
    with requests.get(f"{call_url}/{event_id}", stream=True, timeout=20) as result_resp:
        if result_resp.status_code != 200:
            return {"error": f"HTTP {result_resp.status_code}"}
        # Parse the SSE stream: the payload rides on lines prefixed "data: ".
        for raw_line in result_resp.iter_lines():
            if not raw_line:
                continue
            line = raw_line.decode('utf-8')  # decode once (was decoded twice)
            if line.startswith('data: '):
                try:
                    result_data = json.loads(line[6:])
                except json.JSONDecodeError:
                    continue
                if isinstance(result_data, list) and len(result_data) > 0:
                    return {"text": result_data[0]}
    return {"error": "No result"}
# ============================================================
# End of MCP service invocation section
# ============================================================
def chatbot_response(message, history):
    """Main assistant entry point: a streaming generator for gr.ChatInterface.

    Flow: build the message list (system prompt + recent history + new user
    message), run up to 5 non-streaming LLM rounds resolving tool calls via
    call_mcp_tool, then yield a markdown summary of the tools used followed
    by a streamed final answer. On failure, yields an error message instead
    of raising (Gradio displays whatever is yielded).
    """
    try:
        messages = [{"role": "system", "content": get_system_prompt()}]
        # Replay recent history (last 5 turns) in tuple form.
        # NOTE(review): assumes (user, assistant) pairs; entries in any other
        # shape (e.g. message dicts) are silently skipped.
        if history:
            for item in history[-5:]:
                if isinstance(item, (list, tuple)) and len(item) == 2:
                    messages.append({"role": "user", "content": item[0]})
                    messages.append({"role": "assistant", "content": item[1]})
        messages.append({"role": "user", "content": message})
        tool_calls_log = []
        # LLM call loop (at most 5 rounds of tool calls).
        for iteration in range(5):
            response = client.chat_completion(
                messages=messages,
                model="Qwen/Qwen2.5-72B-Instruct:novita",
                tools=MCP_TOOLS,
                max_tokens=2000,
                temperature=0.5,
                tool_choice="auto",
                stream=False
            )
            choice = response.choices[0]
            if choice.message.tool_calls:
                # NOTE(review): appends the client's message object (not a
                # plain dict); InferenceClient accepts this — confirm if the
                # client library is ever swapped.
                messages.append(choice.message)
                for tool_call in choice.message.tool_calls:
                    tool_name = tool_call.function.name
                    tool_args = json.loads(tool_call.function.arguments)
                    # Invoke the MCP tool.
                    tool_result = call_mcp_tool(tool_name, tool_args)
                    # Cap the size fed back to the LLM; over-long content has
                    # caused 500 errors from the inference endpoint.
                    result_str = json.dumps(tool_result, ensure_ascii=False)
                    if len(result_str) > 4000:
                        # Truncate over-long results.
                        tool_result_truncated = {"_truncated": True, "preview": result_str[:4000] + "..."}
                        result_for_llm = json.dumps(tool_result_truncated)
                    else:
                        result_for_llm = result_str
                    # Record the tool call (with its full result) for display.
                    tool_calls_log.append({"name": tool_name, "arguments": tool_args, "result": tool_result})
                    messages.append({
                        "role": "tool",
                        "name": tool_name,
                        "content": result_for_llm,
                        "tool_call_id": tool_call.id
                    })
                continue
            else:
                # No tool calls requested: the model is ready to answer.
                break
        # Build the response prefix shown above the final answer.
        response_prefix = ""
        # Render tool calls (plain text + collapsible raw output).
        if tool_calls_log:
            response_prefix += "**🛠️ Tools Used:**\n\n"
            for idx, tool_call in enumerate(tool_calls_log):
                # Tool name and input arguments.
                response_prefix += f"**{idx+1}. `{tool_call['name']}`**\n"
                response_prefix += f"- 📥 Input: `{json.dumps(tool_call['arguments'], ensure_ascii=False)}`\n"
                # Output (collapsible <details> block).
                if 'result' in tool_call:
                    result_str = json.dumps(tool_call['result'], ensure_ascii=False, indent=2)
                    # Truncate over-long results for display.
                    if len(result_str) > 500:
                        result_preview = result_str[:500] + "..."
                    else:
                        result_preview = result_str
                    response_prefix += f"<details>\n<summary>📤 Output (click to expand)</summary>\n\n```json\n{result_preview}\n```\n</details>\n\n"
                else:
                    response_prefix += "\n"
            response_prefix += "---\n\n"
        # Stream the final answer, always prefixed by the tool summary.
        yield response_prefix
        stream = client.chat_completion(
            messages=messages,
            model="Qwen/Qwen2.5-72B-Instruct:novita",
            tools=MCP_TOOLS,
            max_tokens=2000,
            temperature=0.5,
            stream=True
        )
        accumulated_text = ""
        for chunk in stream:
            if chunk.choices and len(chunk.choices) > 0 and chunk.choices[0].delta.content:
                accumulated_text += chunk.choices[0].delta.content
                # Gradio replaces the message on each yield, so re-yield the
                # whole prefix + accumulated text.
                yield response_prefix + accumulated_text
    except Exception as e:
        import traceback
        error_detail = str(e)
        if "500" in error_detail:
            yield f"❌ Error: 模型服务器错误。可能是数据太大或请求超时。\n\n详细信息: {error_detail[:200]}"
        else:
            yield f"❌ Error: {error_detail}\n\n{traceback.format_exc()[:500]}"
# ========== Gradio UI (minimal) ==========
with gr.Blocks(title="Financial AI Assistant") as demo:
    gr.Markdown("# 💬 Financial AI Assistant")
    # ChatInterface drives the streaming generator above; the examples are
    # clickable starter prompts.
    chat = gr.ChatInterface(
        fn=chatbot_response,
        examples=[
            "What's Apple's latest revenue and profit?",
            "Show me NVIDIA's 3-year financial trends",
            "How is Tesla's stock performing today?",
            "Get the latest market news about crypto",
            "Compare Microsoft's latest earnings with its current stock price",
        ],
        chatbot=gr.Chatbot(height=600),
    )
# Launch the app
if __name__ == "__main__":
    import sys
    import asyncio  # fix: was imported twice (here and again inside the try)

    # Work around asyncio event-loop issues on Linux (e.g. in HF Spaces).
    if sys.platform == 'linux':
        try:
            asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy())
        except Exception:
            # Fix: narrowed from a bare `except:` (which also swallowed
            # SystemExit/KeyboardInterrupt). Best-effort: fall back to the
            # default policy resolution.
            pass
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
        ssr_mode=False,
        quiet=False
    )