trading_floor / traders.py
Denis Mbugua
configure
778e241
from contextlib import AsyncExitStack
from accounts_client import read_accounts_resource, read_strategy_resource
from tracers import make_trace_id
from agents import Agent, Tool, Runner, OpenAIChatCompletionsModel, trace
from openai import AsyncOpenAI
from dotenv import load_dotenv
import os
import json
from agents.mcp import MCPServerStdio
from templates import (
researcher_instructions,
trader_instructions,
trade_message,
rebalance_message,
research_tool,
)
from mcp_params import trader_mcp_server_params, researcher_mcp_server_params
# Endpoints for the OpenAI-compatible providers used via AsyncOpenAI below.
DEEPSEEK_BASE_URL = "https://api.deepseek.com/v1"
GROK_BASE_URL = "https://api.x.ai/v1"
GEMINI_BASE_URL = "https://generativelanguage.googleapis.com/v1beta/openai/"
OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1"

# Upper bound on agent turns for a single Runner.run call.
MAX_TURNS = 30

# Load .env before reading any keys; override=True lets values from .env
# replace variables that are already set in the process environment.
load_dotenv(override=True)

deepseek_api_key = os.environ.get("DEEPSEEK_API_KEY")
google_api_key = os.environ.get("GOOGLE_API_KEY")
grok_api_key = os.environ.get("GROK_API_KEY")
openrouter_api_key = os.environ.get("OPENROUTER_API_KEY")

# One async client per provider, built once at import time.
# NOTE(review): when a provider key is unset (None), AsyncOpenAI falls back to
# OPENAI_API_KEY (sending it to that provider) and raises if that is unset
# too — confirm this is acceptable.
openrouter_client = AsyncOpenAI(api_key=openrouter_api_key, base_url=OPENROUTER_BASE_URL)
deepseek_client = AsyncOpenAI(api_key=deepseek_api_key, base_url=DEEPSEEK_BASE_URL)
grok_client = AsyncOpenAI(api_key=grok_api_key, base_url=GROK_BASE_URL)
gemini_client = AsyncOpenAI(api_key=google_api_key, base_url=GEMINI_BASE_URL)
def get_model(model_name: str):
    """Resolve a model name to whatever the agents ``Agent(model=...)`` expects.

    Names containing "/" are routed to OpenRouter; otherwise the first of
    "deepseek", "grok", "gemini" found in the name selects that provider's
    client. In those cases an ``OpenAIChatCompletionsModel`` is returned.
    Any other name is returned unchanged as a plain string (presumably
    resolved by the agents SDK's default provider — confirm).
    """
    if "/" in model_name:
        client = openrouter_client
    elif "deepseek" in model_name:
        client = deepseek_client
    elif "grok" in model_name:
        client = grok_client
    elif "gemini" in model_name:
        client = gemini_client
    else:
        return model_name
    return OpenAIChatCompletionsModel(model=model_name, openai_client=client)
async def get_researcher(mcp_servers, model_name) -> Agent:
    """Build a fresh agent named "Researcher".

    Args:
        mcp_servers: MCP servers handed to the agent.
        model_name: Model identifier, resolved through ``get_model``.

    Returns:
        The new ``Agent``, configured with ``researcher_instructions()``.
    """
    instructions = researcher_instructions()
    model = get_model(model_name)
    return Agent(
        name="Researcher",
        instructions=instructions,
        model=model,
        mcp_servers=mcp_servers,
    )
async def get_researcher_tool(mcp_servers, model_name) -> Tool:
    """Expose the Researcher agent as a tool named "Researcher".

    Args:
        mcp_servers: MCP servers for the underlying researcher agent.
        model_name: Model identifier for the researcher agent.

    Returns:
        The tool produced by ``Agent.as_tool`` with ``research_tool()`` as
        its description.
    """
    research_agent = await get_researcher(mcp_servers, model_name)
    description = research_tool()
    return research_agent.as_tool(tool_name="Researcher", tool_description=description)
class Trader:
    """An agent-driven trader that alternates between trade and rebalance runs.

    Each ``run`` call does the following:

    - starts the trader and researcher MCP servers;
    - builds a fresh agent with the Researcher agent attached as a tool;
    - sends the agent either ``trade_message`` or ``rebalance_message``;
    - flips ``do_trade`` so the next run uses the other message.
    """

    # Seconds passed as ``client_session_timeout_seconds`` to every MCP server.
    # Override per instance or subclass to change it.
    mcp_timeout_seconds = 120

    def __init__(self, name: str, lastname="Trader", model_name="gpt-4o-mini"):
        """Create a trader.

        Args:
            name: Trader name. It is used as the agent name and in trace names,
                and it is passed to the account/strategy readers and to
                ``researcher_mcp_server_params``.
            lastname: Stored on the instance; not used elsewhere in this class.
            model_name: Model identifier resolved through ``get_model``.
        """
        self.name = name
        self.lastname = lastname
        self.agent = None  # Populated by create_agent().
        self.model_name = model_name
        # True -> next run sends trade_message; False -> rebalance_message.
        self.do_trade = True

    async def create_agent(self, trader_mcp_servers, researcher_mcp_servers) -> Agent:
        """Build the trader agent and store it on ``self.agent``.

        The Researcher agent, running on ``researcher_mcp_servers``, is attached
        as the trader's only tool. The trader itself uses ``trader_mcp_servers``.

        Returns:
            The newly created agent.
        """
        tool = await get_researcher_tool(researcher_mcp_servers, self.model_name)
        self.agent = Agent(
            name=self.name,
            instructions=trader_instructions(self.name),
            model=get_model(self.model_name),
            tools=[tool],
            mcp_servers=trader_mcp_servers,
        )
        return self.agent

    async def get_account_report(self) -> str:
        """Return this trader's account as a JSON string.

        The ``portfolio_value_time_series`` entry is removed if present
        (presumably to keep the prompt small).
        """
        account = await read_accounts_resource(self.name)
        account_json = json.loads(account)
        account_json.pop("portfolio_value_time_series", None)
        return json.dumps(account_json)

    async def run_agent(self, trader_mcp_servers, researcher_mcp_servers):
        """Build the agent and run it once on the trade or rebalance message."""
        # create_agent already stores the agent on self.agent.
        agent = await self.create_agent(trader_mcp_servers, researcher_mcp_servers)
        account = await self.get_account_report()
        strategy = await read_strategy_resource(self.name)
        message = (
            trade_message(self.name, strategy, account)
            if self.do_trade
            else rebalance_message(self.name, strategy, account)
        )
        await Runner.run(agent, message, max_turns=MAX_TURNS)

    def _mcp_server(self, params):
        """Return an MCPServerStdio for ``params`` using the shared timeout."""
        return MCPServerStdio(
            params, client_session_timeout_seconds=self.mcp_timeout_seconds
        )

    async def run_with_mcp_servers(self):
        """Start all trader and researcher MCP servers, then run the agent.

        One AsyncExitStack owns every server, so all of them are exited in
        reverse start order (researcher servers first, then trader servers)
        when the run finishes or raises.
        """
        async with AsyncExitStack() as stack:
            trader_mcp_servers = [
                await stack.enter_async_context(self._mcp_server(params))
                for params in trader_mcp_server_params
            ]
            researcher_mcp_servers = [
                await stack.enter_async_context(self._mcp_server(params))
                for params in researcher_mcp_server_params(self.name)
            ]
            await self.run_agent(trader_mcp_servers, researcher_mcp_servers)

    async def run_with_trace(self):
        """Run one session inside a trace named for the trader and the run type."""
        trace_name = f"{self.name}-trading" if self.do_trade else f"{self.name}-rebalancing"
        trace_id = make_trace_id(self.name.lower())
        with trace(trace_name, trace_id=trace_id):
            await self.run_with_mcp_servers()

    async def run(self):
        """Run one trading or rebalancing session, then flip ``do_trade``.

        Any ``Exception`` raised by the session is printed and not re-raised.
        ``do_trade`` is toggled whether or not the session succeeded.
        """
        try:
            await self.run_with_trace()
        except Exception as e:
            print(f"Error running trader {self.name}: {e}")
        self.do_trade = not self.do_trade