Spaces:
Sleeping
Sleeping
| from contextlib import AsyncExitStack | |
| from accounts_client import read_accounts_resource, read_strategy_resource | |
| from tracers import make_trace_id | |
| from agents import Agent, Tool, Runner, OpenAIChatCompletionsModel, trace | |
| from openai import AsyncOpenAI | |
| from dotenv import load_dotenv | |
| import os | |
| import json | |
| from agents.mcp import MCPServerStdio | |
| from templates import ( | |
| researcher_instructions, | |
| trader_instructions, | |
| trade_message, | |
| rebalance_message, | |
| research_tool, | |
| ) | |
| from mcp_params import trader_mcp_server_params, researcher_mcp_server_params | |
# Endpoints of the OpenAI-compatible providers used by get_model().
DEEPSEEK_BASE_URL = "https://api.deepseek.com/v1"
GROK_BASE_URL = "https://api.x.ai/v1"
GEMINI_BASE_URL = "https://generativelanguage.googleapis.com/v1beta/openai/"
OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1"

# Cap on agent turns for a single Runner.run() call.
MAX_TURNS = 30

# Load a local .env into the process environment before any key is read.
# With override=True, .env values take precedence over variables that are
# already set (python-dotenv semantics).
load_dotenv(override=True)

# Provider credentials; each is None when its variable is absent.
deepseek_api_key = os.environ.get("DEEPSEEK_API_KEY")
google_api_key = os.environ.get("GOOGLE_API_KEY")
grok_api_key = os.environ.get("GROK_API_KEY")
openrouter_api_key = os.environ.get("OPENROUTER_API_KEY")
# One async client per provider, each aimed at that provider's OpenAI-compatible
# endpoint; get_model() chooses among them based on the model name.
# NOTE(review): an unset key reaches AsyncOpenAI as api_key=None — confirm how
# the client resolves that (it may fall back to OPENAI_API_KEY or raise).
openrouter_client = AsyncOpenAI(api_key=openrouter_api_key, base_url=OPENROUTER_BASE_URL)
deepseek_client = AsyncOpenAI(api_key=deepseek_api_key, base_url=DEEPSEEK_BASE_URL)
grok_client = AsyncOpenAI(api_key=grok_api_key, base_url=GROK_BASE_URL)
gemini_client = AsyncOpenAI(api_key=google_api_key, base_url=GEMINI_BASE_URL)
def get_model(model_name: str):
    """Resolve *model_name* into the model value an ``Agent`` should receive.

    Rules are checked in this order:
      * a name containing "/" is routed through the OpenRouter client;
      * a name containing "deepseek", "grok" or "gemini" is routed through
        that provider's client;
      * any other name is returned unchanged as a plain string (presumably
        resolved by the Agents SDK's default provider — confirm).
    """
    if "/" in model_name:
        client = openrouter_client
    elif "deepseek" in model_name:
        client = deepseek_client
    elif "grok" in model_name:
        client = grok_client
    elif "gemini" in model_name:
        client = gemini_client
    else:
        return model_name
    return OpenAIChatCompletionsModel(model=model_name, openai_client=client)
async def get_researcher(mcp_servers, model_name) -> Agent:
    """Create the "Researcher" agent.

    Args:
        mcp_servers: MCP servers whose tools the researcher may use.
        model_name: Model identifier, resolved through ``get_model``.

    Returns:
        A new ``Agent`` configured with ``researcher_instructions()``.
    """
    instructions = researcher_instructions()
    model = get_model(model_name)
    return Agent(
        name="Researcher",
        instructions=instructions,
        model=model,
        mcp_servers=mcp_servers,
    )
async def get_researcher_tool(mcp_servers, model_name) -> Tool:
    """Expose a fresh researcher agent as a tool named "Researcher".

    The tool description comes from ``research_tool()``; the arguments are
    forwarded unchanged to ``get_researcher``.
    """
    agent = await get_researcher(mcp_servers, model_name)
    description = research_tool()
    return agent.as_tool(tool_name="Researcher", tool_description=description)
class Trader:
    """An LLM trading agent that alternates between trading and rebalancing.

    Each :meth:`run` starts the MCP servers and builds a fresh agent. The
    agent has a researcher sub-agent exposed as a tool. It is sent either the
    trade prompt or the rebalance prompt, and :attr:`do_trade` is then flipped
    so the next run uses the other prompt.
    """

    # Session timeout, in seconds, given to every MCPServerStdio started here.
    MCP_SESSION_TIMEOUT_SECONDS = 120

    def __init__(self, name: str, lastname: str = "Trader", model_name: str = "gpt-4o-mini"):
        self.name = name
        self.lastname = lastname
        # Set by create_agent(); None until the first run builds an agent.
        self.agent = None
        self.model_name = model_name
        # True -> next run sends the trade prompt; False -> the rebalance prompt.
        self.do_trade = True

    async def create_agent(self, trader_mcp_servers, researcher_mcp_servers) -> "Agent":
        """Build the trader agent, store it on ``self.agent`` and return it.

        Args:
            trader_mcp_servers: MCP servers attached directly to the trader agent.
            researcher_mcp_servers: MCP servers given to the researcher
                sub-agent, which the trader reaches through a single tool.
        """
        tool = await get_researcher_tool(researcher_mcp_servers, self.model_name)
        self.agent = Agent(
            name=self.name,
            instructions=trader_instructions(self.name),
            model=get_model(self.model_name),
            tools=[tool],
            mcp_servers=trader_mcp_servers,
        )
        return self.agent

    async def get_account_report(self) -> str:
        """Return this trader's account as JSON, without its value history.

        The ``portfolio_value_time_series`` key is removed if present before
        re-serializing — presumably to keep the prompt small; confirm.
        """
        account = await read_accounts_resource(self.name)
        account_json = json.loads(account)
        account_json.pop("portfolio_value_time_series", None)
        return json.dumps(account_json)

    async def run_agent(self, trader_mcp_servers, researcher_mcp_servers):
        """Build the agent and run it once on the trade or rebalance prompt.

        The prompt is built from the current account report and strategy.
        It is the trade prompt when ``do_trade`` is true, otherwise the
        rebalance prompt. The run is limited to ``MAX_TURNS`` turns.
        """
        self.agent = await self.create_agent(trader_mcp_servers, researcher_mcp_servers)
        account = await self.get_account_report()
        strategy = await read_strategy_resource(self.name)
        message = (
            trade_message(self.name, strategy, account)
            if self.do_trade
            else rebalance_message(self.name, strategy, account)
        )
        await Runner.run(self.agent, message, max_turns=MAX_TURNS)

    async def _enter_mcp_servers(self, stack, params_list):
        """Start one stdio MCP server per entry of *params_list*, owned by *stack*."""
        return [
            await stack.enter_async_context(
                MCPServerStdio(params, client_session_timeout_seconds=self.MCP_SESSION_TIMEOUT_SECONDS)
            )
            for params in params_list
        ]

    async def run_with_mcp_servers(self):
        """Start the trader and researcher MCP servers, run the agent, then close them."""
        # A single stack unwinds in reverse entry order, so the researcher
        # servers are closed before the trader servers.
        async with AsyncExitStack() as stack:
            trader_mcp_servers = await self._enter_mcp_servers(stack, trader_mcp_server_params)
            researcher_mcp_servers = await self._enter_mcp_servers(
                stack, researcher_mcp_server_params(self.name)
            )
            await self.run_agent(trader_mcp_servers, researcher_mcp_servers)

    async def run_with_trace(self):
        """Run one session inside a trace named ``<name>-trading`` or ``<name>-rebalancing``."""
        trace_name = f"{self.name}-trading" if self.do_trade else f"{self.name}-rebalancing"
        trace_id = make_trace_id(self.name.lower())
        with trace(trace_name, trace_id=trace_id):
            await self.run_with_mcp_servers()

    async def run(self):
        """Run one trading or rebalancing session, then flip the mode.

        An ``Exception`` raised during the session does not propagate.
        Instead, a one-line summary is printed to stdout and the traceback
        to stderr. The mode flips afterwards either way.
        """
        try:
            await self.run_with_trace()
        except Exception as e:
            # str(e) alone hides the exception type and origin (and is empty
            # for message-less exceptions), so also emit the full traceback.
            import traceback

            print(f"Error running trader {self.name}: {e}")
            traceback.print_exc()
        self.do_trade = not self.do_trade