stonks / app.py
thomfoolery's picture
add mcp
79a4fbe
import json
from mcp import ClientSession
import anthropic
import chainlit as cl
anthropic_client = anthropic.AsyncAnthropic()
SYSTEM = "You are a helpful stock trader that help me identify opportunities and risks in the stock market. " \
"Provide brief and concise reports of current market conditions, including key indicators and trends. " \
"Use the tools provided to gather data and insights. " \
"If you need to use a tool, please do so and provide the results in your response. " \
"You can also ask clarifying questions if needed. " \
"You try to maximize the profit of the user. "
def flatten(xss):
return [x for xs in xss for x in xs]
@cl.on_mcp_connect
async def on_mcp(connection, session: ClientSession):
result = await session.list_tools()
tools = [{
"name": t.name,
"description": t.description,
"input_schema": t.inputSchema,
} for t in result.tools]
mcp_tools = cl.user_session.get("mcp_tools", {})
mcp_tools[connection.name] = tools
cl.user_session.set("mcp_tools", mcp_tools)
@cl.step(type="tool")
async def call_tool(tool_use):
tool_name = tool_use.name
tool_input = tool_use.input
current_step = cl.context.current_step
current_step.name = tool_name
# Identify which mcp is used
mcp_tools = cl.user_session.get("mcp_tools", {})
mcp_name = None
for connection_name, tools in mcp_tools.items():
if any(tool.get("name") == tool_name for tool in tools):
mcp_name = connection_name
break
if not mcp_name:
current_step.output = json.dumps({"error": f"Tool {tool_name} not found in any MCP connection"})
return current_step.output
mcp_session, _ = cl.context.session.mcp_sessions.get(mcp_name)
if not mcp_session:
current_step.output = json.dumps({"error": f"MCP {mcp_name} not found in any MCP connection"})
return current_step.output
try:
current_step.output = await mcp_session.call_tool(tool_name, tool_input)
except Exception as e:
current_step.output = json.dumps({"error": str(e)})
return current_step.output
async def call_claude(chat_messages):
msg = cl.Message(content="")
mcp_tools = cl.user_session.get("mcp_tools", {})
# Flatten the tools from all MCP connections
tools = flatten([tools for _, tools in mcp_tools.items()])
async with anthropic_client.messages.stream(
system=SYSTEM,
max_tokens=1024,
messages=chat_messages,
tools=tools,
model="claude-3-5-sonnet-20240620",
) as stream:
async for text in stream.text_stream:
await msg.stream_token(text)
await msg.send()
response = await stream.get_final_message()
return response
@cl.on_chat_start
async def start_chat():
cl.user_session.set("chat_messages", [])
@cl.on_message
async def on_message(msg: cl.Message):
chat_messages = cl.user_session.get("chat_messages")
chat_messages.append({"role": "user", "content": msg.content})
response = await call_claude(chat_messages)
while response.stop_reason == "tool_use":
tool_use = next(block for block in response.content if block.type == "tool_use")
tool_result = await call_tool(tool_use)
messages = [
{"role": "assistant", "content": response.content},
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": tool_use.id,
"content": str(tool_result),
}
],
},
]
chat_messages.extend(messages)
response = await call_claude(chat_messages)
final_response = next(
(block.text for block in response.content if hasattr(block, "text")),
None,
)
chat_messages = cl.user_session.get("chat_messages")
chat_messages.append({"role": "assistant", "content": final_response})