Spaces:
Sleeping
Sleeping
Midas Deploy Bot commited on
Commit ·
9419f40
0
Parent(s):
Deploy: 8c89a8fbe906551039281c5c5a0089572c945b0b
Browse files- Dockerfile +24 -0
- README.md +9 -0
- __init__.py +0 -0
- agent_factory.py +42 -0
- agents.yaml +39 -0
- application.py +231 -0
- base_agent.py +238 -0
- base_tool.py +36 -0
- llm_provider.py +171 -0
- memory.py +68 -0
- registry.py +29 -0
- requirements.txt +93 -0
- tools.py +147 -0
Dockerfile
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Start with Python
|
| 2 |
+
FROM python:3.11-slim
|
| 3 |
+
|
| 4 |
+
# Set the working directory inside the container
|
| 5 |
+
WORKDIR /app
|
| 6 |
+
|
| 7 |
+
# 1. Copy requirements (Relative to current folder, no 'backend/' prefix)
|
| 8 |
+
COPY requirements.txt .
|
| 9 |
+
|
| 10 |
+
# 2. Install dependencies
|
| 11 |
+
# Added --no-cache-dir to keep image size small
|
| 12 |
+
RUN pip install --no-cache-dir --upgrade pip && \
|
| 13 |
+
pip install --no-cache-dir -r requirements.txt
|
| 14 |
+
|
| 15 |
+
# This ARG changes whenever force rebuild is needed.
|
| 16 |
+
# We don't even need to pass a value; just the existence of a changed line forces a rebuild.
|
| 17 |
+
ARG CACHEBUST=20251217
|
| 18 |
+
|
| 19 |
+
# 3. Copy the rest of the code (Current folder -> /app)
|
| 20 |
+
COPY . .
|
| 21 |
+
|
| 22 |
+
# 4. Run the application
|
| 23 |
+
# Note: application:app works because application.py is now at /app/application.py
|
| 24 |
+
CMD ["uvicorn", "application:app", "--host", "0.0.0.0", "--port", "7860"]
|
README.md
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Midas Backend
|
| 3 |
+
emoji: 🧠
|
| 4 |
+
colorFrom: gray
|
| 5 |
+
colorTo: blue
|
| 6 |
+
sdk: docker
|
| 7 |
+
app_port: 7860
|
| 8 |
+
---
|
| 9 |
+
# Midas Protocol Backend
|
__init__.py
ADDED
|
File without changes
|
agent_factory.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import ruamel.yaml as yaml
|
| 2 |
+
from typing import Dict
|
| 3 |
+
from base_agent import SingleAgent
|
| 4 |
+
from registry import ToolRegistry
|
| 5 |
+
|
| 6 |
+
class AgentFactory:
|
| 7 |
+
def __init__(self, registry: ToolRegistry):
|
| 8 |
+
self.registry = registry
|
| 9 |
+
self.parser = yaml.YAML(typ='safe')
|
| 10 |
+
self.parser.pure = True
|
| 11 |
+
|
| 12 |
+
def load_from_yaml(self, config_path: str) -> Dict[str, SingleAgent]:
|
| 13 |
+
"""
|
| 14 |
+
Reads YAML and returns a dictionary of built Agents.
|
| 15 |
+
"""
|
| 16 |
+
with open(config_path, "r") as f:
|
| 17 |
+
config = self.parser.load(f)
|
| 18 |
+
|
| 19 |
+
agents_dict = {}
|
| 20 |
+
|
| 21 |
+
for agent_conf in config["agents"]:
|
| 22 |
+
name = agent_conf["name"]
|
| 23 |
+
subscriptions = agent_conf.get("subscriptions", [])
|
| 24 |
+
sys_prompt = agent_conf.get("system_prompt", "You are a helpful assistant.")
|
| 25 |
+
|
| 26 |
+
unique_tools = set()
|
| 27 |
+
for category in subscriptions:
|
| 28 |
+
tools = self.registry.get_tools_by_category(category)
|
| 29 |
+
for t in tools:
|
| 30 |
+
unique_tools.add(t)
|
| 31 |
+
|
| 32 |
+
agent_tools_list = list(unique_tools)
|
| 33 |
+
|
| 34 |
+
new_agent = SingleAgent(
|
| 35 |
+
name=name,
|
| 36 |
+
tools=agent_tools_list,
|
| 37 |
+
system_prompt=sys_prompt
|
| 38 |
+
)
|
| 39 |
+
agents_dict[name] = new_agent
|
| 40 |
+
|
| 41 |
+
return agents_dict
|
| 42 |
+
|
agents.yaml
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
agents:
|
| 2 |
+
- name: "PriceWorker"
|
| 3 |
+
model: "llama-3.3-70b-versatile"
|
| 4 |
+
description: "An expert in financial data and stock market valuations."
|
| 5 |
+
subscriptions:
|
| 6 |
+
- "finance"
|
| 7 |
+
- "utils"
|
| 8 |
+
system_prompt: |
|
| 9 |
+
You are a specialized financial analyst.
|
| 10 |
+
Your goal is to retrieve accurate stock prices and perform precise calculations.
|
| 11 |
+
|
| 12 |
+
CRITICAL TOOL INSTRUCTIONS:
|
| 13 |
+
1. When calling tools, use standard JSON format only.
|
| 14 |
+
2. NEVER use XML tags like <function> or <tool_code>.
|
| 15 |
+
3. Use 'get_stock_price' to find data.
|
| 16 |
+
4. NEVER do math in your head. ALWAYS use the 'calculator' tool.
|
| 17 |
+
5. **NO NESTED TOOLS**: You cannot pass a tool call as an argument to another tool.
|
| 18 |
+
- INCORRECT: calculator(x=get_stock_price("AAPL")...)
|
| 19 |
+
- CORRECT:
|
| 20 |
+
Step 1: Call get_stock_price("AAPL")
|
| 21 |
+
Step 2: Wait for result.
|
| 22 |
+
Step 3: Call calculator(x=200, ...)
|
| 23 |
+
6. Output standard JSON for tool calls.
|
| 24 |
+
|
| 25 |
+
- name: "NewsWorker"
|
| 26 |
+
model: "llama-3.3-70b-versatile"
|
| 27 |
+
description: "An expert in market sentiment and company news."
|
| 28 |
+
subscriptions:
|
| 29 |
+
- "news"
|
| 30 |
+
system_prompt: |
|
| 31 |
+
You are a specialized news reporter.
|
| 32 |
+
Your goal is to summarize the latest stories affecting a company.
|
| 33 |
+
Focus on headlines that might impact stock value.
|
| 34 |
+
Be concise and objective.
|
| 35 |
+
|
| 36 |
+
CRITICAL TOOL INSTRUCTIONS:
|
| 37 |
+
1. When calling tools, use standard JSON format only.
|
| 38 |
+
2. NEVER use XML tags like <function> or <tool_code>.
|
| 39 |
+
3. ALWAYS cite specific headlines in your final answer.
|
application.py
ADDED
|
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import FastAPI, HTTPException, Header
|
| 2 |
+
from pydantic import BaseModel
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Optional
|
| 5 |
+
from enum import Enum
|
| 6 |
+
from dotenv import load_dotenv
|
| 7 |
+
|
| 8 |
+
from memory import TokenBufferMemory
|
| 9 |
+
from agent_factory import AgentFactory
|
| 10 |
+
from base_agent import ManagerAgent
|
| 11 |
+
from tools import initialize_registry
|
| 12 |
+
from llm_provider import GroqProvider, GeminiProvider, QuotaExhaustedError, ProviderDownError, ProviderError
|
| 13 |
+
|
| 14 |
+
import time
|
| 15 |
+
import uvicorn
|
| 16 |
+
import logging
|
| 17 |
+
import os
|
| 18 |
+
import sys
|
| 19 |
+
|
| 20 |
+
load_dotenv()
|
| 21 |
+
|
| 22 |
+
logging.basicConfig(
|
| 23 |
+
level=logging.INFO,
|
| 24 |
+
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
|
| 25 |
+
handlers=[
|
| 26 |
+
logging.FileHandler("agent_debug.log"),
|
| 27 |
+
logging.StreamHandler(sys.stdout)
|
| 28 |
+
]
|
| 29 |
+
)
|
| 30 |
+
logger = logging.getLogger("SystemBackend")
|
| 31 |
+
|
| 32 |
+
class ChatRequest(BaseModel):
|
| 33 |
+
query: str
|
| 34 |
+
provider: Optional[str] = None
|
| 35 |
+
api_key: Optional[str] = None
|
| 36 |
+
|
| 37 |
+
class ChatResponse(BaseModel):
|
| 38 |
+
success: bool
|
| 39 |
+
response: Optional[str] = None
|
| 40 |
+
provider_used: Optional[str] = None
|
| 41 |
+
agent_used: Optional[str] = None
|
| 42 |
+
error_type: Optional[str] = None
|
| 43 |
+
required_provider: Optional[str] = None
|
| 44 |
+
message: Optional[str] = None
|
| 45 |
+
|
| 46 |
+
app = FastAPI()
|
| 47 |
+
registry = initialize_registry()
|
| 48 |
+
factory = AgentFactory(registry)
|
| 49 |
+
|
| 50 |
+
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 51 |
+
YAML_PATH = os.path.join(BASE_DIR, "agents.yaml")
|
| 52 |
+
try:
|
| 53 |
+
workers_map = factory.load_from_yaml(YAML_PATH)
|
| 54 |
+
logger.info(f"✅ Successfully loaded agents: {list(workers_map.keys())}")
|
| 55 |
+
except FileNotFoundError:
|
| 56 |
+
logger.critical("❌ 'agents.yaml' not found! Please create this configuration file.")
|
| 57 |
+
sys.exit(1)
|
| 58 |
+
|
| 59 |
+
agent_memory = TokenBufferMemory(max_tokens=4096)
|
| 60 |
+
|
| 61 |
+
manager_agent = ManagerAgent(
|
| 62 |
+
name="Manager",
|
| 63 |
+
sub_agents=workers_map,
|
| 64 |
+
memory=agent_memory
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class ProviderStatus(Enum):
|
| 69 |
+
ACTIVE = "Active"
|
| 70 |
+
DOWN = "Down"
|
| 71 |
+
QUOTA_EXHAUSTED = "QuotaExhausted"
|
| 72 |
+
ERROR = "Error"
|
| 73 |
+
|
| 74 |
+
@dataclass
|
| 75 |
+
class ProviderState:
|
| 76 |
+
name: str
|
| 77 |
+
status: ProviderStatus
|
| 78 |
+
reset_time: float
|
| 79 |
+
|
| 80 |
+
class ProviderManager:
|
| 81 |
+
def __init__(self):
|
| 82 |
+
self.providers = {
|
| 83 |
+
"groq": ProviderState(name="groq", status=ProviderStatus.ACTIVE, reset_time=0.0),
|
| 84 |
+
"gemini": ProviderState(name="gemini", status=ProviderStatus.ACTIVE, reset_time=0.0),
|
| 85 |
+
}
|
| 86 |
+
|
| 87 |
+
def update_status(self, provider_name: str, status: ProviderStatus):
|
| 88 |
+
self.providers[provider_name].status = status
|
| 89 |
+
if status == ProviderStatus.QUOTA_EXHAUSTED:
|
| 90 |
+
self.providers[provider_name].reset_time = time.time() + (24 * 60 * 60) # 24 hours
|
| 91 |
+
elif status == ProviderStatus.DOWN:
|
| 92 |
+
self.providers[provider_name].reset_time = time.time() + 60 # 60 seconds
|
| 93 |
+
|
| 94 |
+
def get_provider(self):
|
| 95 |
+
for name, state in self.providers.items():
|
| 96 |
+
if state.status == ProviderStatus.ACTIVE:
|
| 97 |
+
return name
|
| 98 |
+
elif state.status in [ProviderStatus.DOWN, ProviderStatus.QUOTA_EXHAUSTED]:
|
| 99 |
+
if time.time() > state.reset_time:
|
| 100 |
+
state.status = ProviderStatus.ACTIVE
|
| 101 |
+
return name
|
| 102 |
+
return None
|
| 103 |
+
|
| 104 |
+
def is_provider_active(self, provider_name: str) -> bool:
|
| 105 |
+
if provider_name not in self.providers:
|
| 106 |
+
return False
|
| 107 |
+
state = self.providers[provider_name]
|
| 108 |
+
if state.status in [ProviderStatus.DOWN, ProviderStatus.QUOTA_EXHAUSTED]:
|
| 109 |
+
if time.time() > state.reset_time:
|
| 110 |
+
state.status = ProviderStatus.ACTIVE
|
| 111 |
+
return True
|
| 112 |
+
return False
|
| 113 |
+
|
| 114 |
+
return state.status == ProviderStatus.ACTIVE
|
| 115 |
+
|
| 116 |
+
provider_manager = ProviderManager()
|
| 117 |
+
|
| 118 |
+
class RequestBody(BaseModel):
|
| 119 |
+
query: str
|
| 120 |
+
provider: Optional[str] = None
|
| 121 |
+
api_key: Optional[str] = None
|
| 122 |
+
|
| 123 |
+
def has_server_key(name: str) -> bool:
|
| 124 |
+
if name == "groq" and os.getenv("GROQ_API_KEY"): return True
|
| 125 |
+
if name == "gemini" and os.getenv("GEMINI_API_KEY"): return True
|
| 126 |
+
return False
|
| 127 |
+
|
| 128 |
+
def get_provider_instance(name: str, key: str):
|
| 129 |
+
if name == "groq": return GroqProvider(api_key=key)
|
| 130 |
+
elif name == "gemini": return GeminiProvider(api_key=key)
|
| 131 |
+
raise ValueError(f"Unknown provider: {name}")
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
@app.post("/chat", response_model=ChatResponse)
|
| 135 |
+
async def chat_endpoint(request: ChatRequest):
|
| 136 |
+
|
| 137 |
+
# CASE 1: MANUAL OVERRIDE (User specifically asked for a provider)
|
| 138 |
+
if request.provider:
|
| 139 |
+
target = request.provider.lower()
|
| 140 |
+
|
| 141 |
+
# A. Check if valid/active
|
| 142 |
+
if not provider_manager.is_provider_active(target):
|
| 143 |
+
return ChatResponse(
|
| 144 |
+
success=False,
|
| 145 |
+
error_type="provider_down",
|
| 146 |
+
required_provider=target,
|
| 147 |
+
message=f"Requested provider '{target}' is currently unavailable (Down/Quota)."
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
# B. Resolve Key
|
| 151 |
+
final_key = request.api_key if request.api_key else None
|
| 152 |
+
if not final_key and has_server_key(target):
|
| 153 |
+
final_key = os.getenv(f"{target.upper()}_API_KEY")
|
| 154 |
+
|
| 155 |
+
logger.info(f"🔍 DEBUG CHECK: Target={target}, Key_Type={type(final_key)}, Has_Key={bool(final_key)}")
|
| 156 |
+
|
| 157 |
+
if not final_key:
|
| 158 |
+
return ChatResponse(
|
| 159 |
+
success=False,
|
| 160 |
+
error_type="needs_key",
|
| 161 |
+
required_provider=target,
|
| 162 |
+
message=f"API Key missing for {target}."
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
# C. Execute (NO LOOP - Fail fast if user preference fails)
|
| 166 |
+
try:
|
| 167 |
+
logger.info(f"🔄 Executing User Preference: {target}")
|
| 168 |
+
llm = get_provider_instance(target, final_key)
|
| 169 |
+
result = manager_agent.process_query(request.query, llm)
|
| 170 |
+
return ChatResponse(
|
| 171 |
+
success=True,
|
| 172 |
+
response=result.content,
|
| 173 |
+
provider_used=target,
|
| 174 |
+
agent_used=manager_agent.name
|
| 175 |
+
)
|
| 176 |
+
except (QuotaExhaustedError, ProviderDownError) as e:
|
| 177 |
+
provider_manager.update_status(target, ProviderStatus.QUOTA_EXHAUSTED) # or check exception type
|
| 178 |
+
return ChatResponse(success=False, error_type="provider_down", message=str(e))
|
| 179 |
+
except Exception as e:
|
| 180 |
+
logger.error(f"Server Error: {e}")
|
| 181 |
+
return ChatResponse(success=False, error_type="server_error", message=str(e))
|
| 182 |
+
|
| 183 |
+
# CASE 2: AUTO-PILOT (Loop through available providers)
|
| 184 |
+
else:
|
| 185 |
+
while True:
|
| 186 |
+
current = provider_manager.get_provider()
|
| 187 |
+
|
| 188 |
+
if not current:
|
| 189 |
+
return ChatResponse(
|
| 190 |
+
success=False,
|
| 191 |
+
error_type="all_down",
|
| 192 |
+
message="All providers are currently down or exhausted."
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
# Check Key for the Auto-Selected candidate
|
| 196 |
+
final_key = None
|
| 197 |
+
if has_server_key(current):
|
| 198 |
+
final_key = os.getenv(f"{current.upper()}_API_KEY")
|
| 199 |
+
|
| 200 |
+
if not final_key:
|
| 201 |
+
# If Auto-Router picks a provider we have no key for, we must ask the user.
|
| 202 |
+
return ChatResponse(
|
| 203 |
+
success=False,
|
| 204 |
+
error_type="needs_key",
|
| 205 |
+
required_provider=current,
|
| 206 |
+
message=f"Auto-switching to {current}, but API Key is missing."
|
| 207 |
+
)
|
| 208 |
+
|
| 209 |
+
try:
|
| 210 |
+
logger.info(f"🔄 Auto-Routing via: {current}")
|
| 211 |
+
llm = get_provider_instance(current, final_key)
|
| 212 |
+
result = manager_agent.process_query(request.query, llm)
|
| 213 |
+
return ChatResponse(
|
| 214 |
+
success=True,
|
| 215 |
+
response=result.content,
|
| 216 |
+
provider_used=current,
|
| 217 |
+
agent_used=manager_agent.name
|
| 218 |
+
)
|
| 219 |
+
except QuotaExhaustedError:
|
| 220 |
+
provider_manager.update_status(current, ProviderStatus.QUOTA_EXHAUSTED)
|
| 221 |
+
continue # Try next in loop
|
| 222 |
+
except ProviderDownError:
|
| 223 |
+
provider_manager.update_status(current, ProviderStatus.DOWN)
|
| 224 |
+
continue # Try next in loop
|
| 225 |
+
except Exception as e:
|
| 226 |
+
logger.error(f"Critical Error: {e}")
|
| 227 |
+
return ChatResponse(success=False, error_type="server_error", message=str(e))
|
| 228 |
+
|
| 229 |
+
if __name__ == "__main__":
|
| 230 |
+
port = int(os.getenv("PORT", 7860))
|
| 231 |
+
uvicorn.run(app, host="0.0.0.0", port=port)
|
base_agent.py
ADDED
|
@@ -0,0 +1,238 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
from dataclasses import dataclass, field
|
| 3 |
+
from typing import Dict, List, Any
|
| 4 |
+
import json
|
| 5 |
+
import logging
|
| 6 |
+
|
| 7 |
+
from memory import BaseMemory
|
| 8 |
+
from base_tool import BaseTool
|
| 9 |
+
from llm_provider import LLMProvider, LLMResponse
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger("AgentFramework")
|
| 12 |
+
|
| 13 |
+
@dataclass
|
| 14 |
+
class AgentResponse:
|
| 15 |
+
content: str
|
| 16 |
+
metadata: Dict[str, Any] = field(default_factory=dict)
|
| 17 |
+
|
| 18 |
+
class BaseAgent(ABC):
|
| 19 |
+
"""
|
| 20 |
+
The parent class for all agents.
|
| 21 |
+
Now accepts a clean list of 'BaseTool' objects.
|
| 22 |
+
"""
|
| 23 |
+
def __init__(self, name: str, tools: List[BaseTool], system_prompt: str = "You are a helpful assistant."):
|
| 24 |
+
self.name = name
|
| 25 |
+
self.system_prompt = system_prompt
|
| 26 |
+
|
| 27 |
+
# 1. Build the Registry (Map Name -> Function) for execution
|
| 28 |
+
self.tool_registry = {tool.name: tool.run for tool in tools}
|
| 29 |
+
|
| 30 |
+
# 2. Build the Definitions (List of Schemas) for the LLM
|
| 31 |
+
self.tool_definitions = [tool.get_schema() for tool in tools]
|
| 32 |
+
|
| 33 |
+
@abstractmethod
|
| 34 |
+
def process_query(self, user_query: str, provider: LLMProvider) -> AgentResponse:
|
| 35 |
+
pass
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class SingleAgent(BaseAgent):
|
| 39 |
+
"""
|
| 40 |
+
A standard worker agent that uses the provided BaseTools to answer queries.
|
| 41 |
+
"""
|
| 42 |
+
def __init__(self, name: str, tools: List[BaseTool], system_prompt: str = "You are a helpful assistant."):
|
| 43 |
+
# Pass the tool objects directly to the parent
|
| 44 |
+
super().__init__(name, tools, system_prompt)
|
| 45 |
+
|
| 46 |
+
def process_query(self, user_query: str, provider: LLMProvider) -> AgentResponse:
|
| 47 |
+
messages = [
|
| 48 |
+
{"role": "system", "content": self.system_prompt},
|
| 49 |
+
{"role": "user", "content": user_query}
|
| 50 |
+
]
|
| 51 |
+
|
| 52 |
+
logger.info(f"\n🚀 [{self.name}] Starting Loop...")
|
| 53 |
+
|
| 54 |
+
for turn in range(5):
|
| 55 |
+
logger.info(f"--- Turn {turn + 1} ---")
|
| 56 |
+
|
| 57 |
+
# 1. Ask the Provider (Using the internally built definitions)
|
| 58 |
+
response: LLMResponse = provider.get_response(messages, self.tool_definitions)
|
| 59 |
+
|
| 60 |
+
# 2. Handle Tool Calls
|
| 61 |
+
if response.tool_call:
|
| 62 |
+
tool_name = response.tool_call["name"]
|
| 63 |
+
tool_args = response.tool_call["args"]
|
| 64 |
+
tool_id = response.tool_call.get("id", "call_default")
|
| 65 |
+
|
| 66 |
+
logger.info(f"🤖 Agent Intent: Call `{tool_name}` with {tool_args}")
|
| 67 |
+
|
| 68 |
+
if tool_name in self.tool_registry:
|
| 69 |
+
messages.append({
|
| 70 |
+
"role": "assistant",
|
| 71 |
+
"content": None,
|
| 72 |
+
"tool_calls": [{"id": tool_id, "type": "function", "function": {"name": tool_name, "arguments": json.dumps(tool_args)}}]
|
| 73 |
+
})
|
| 74 |
+
|
| 75 |
+
try:
|
| 76 |
+
# Execution uses the registry built in __init__
|
| 77 |
+
tool_func = self.tool_registry[tool_name]
|
| 78 |
+
raw_result = tool_func(**tool_args)
|
| 79 |
+
result_str = json.dumps(raw_result) if not isinstance(raw_result, str) else raw_result
|
| 80 |
+
|
| 81 |
+
logger.info(f"Tool Output: {result_str}")
|
| 82 |
+
messages.append({"role": "tool", "tool_call_id": tool_id, "name": tool_name, "content": result_str})
|
| 83 |
+
|
| 84 |
+
except Exception as e:
|
| 85 |
+
error_msg = f"Tool Execution Failed: {str(e)}"
|
| 86 |
+
logger.error(error_msg)
|
| 87 |
+
messages.append({"role": "tool", "tool_call_id": tool_id, "name": tool_name, "content": error_msg})
|
| 88 |
+
continue
|
| 89 |
+
else:
|
| 90 |
+
messages.append({"role": "tool", "tool_call_id": tool_id, "name": tool_name, "content": f"❌ Unknown tool '{tool_name}'"})
|
| 91 |
+
continue
|
| 92 |
+
|
| 93 |
+
# 3. Handle Final Answer
|
| 94 |
+
if response.content:
|
| 95 |
+
logger.info(f"[{self.name}] Final Answer: {response.content}")
|
| 96 |
+
return AgentResponse(content=response.content, metadata={"final_answer": response.content})
|
| 97 |
+
|
| 98 |
+
return AgentResponse(content="Agent timed out.", metadata={"error": "Timeout"})
|
| 99 |
+
|
| 100 |
+
class ManagerAgent(BaseAgent):
|
| 101 |
+
"""
|
| 102 |
+
The Brain.
|
| 103 |
+
It treats its sub-agents as 'Tools' and dynamically decides which one to call.
|
| 104 |
+
Now equipped with Short-Term Memory!
|
| 105 |
+
"""
|
| 106 |
+
def __init__(self, name: str, sub_agents: Dict[str, SingleAgent], memory: BaseMemory, system_prompt: str = "You are a manager."):
|
| 107 |
+
super().__init__(name, tools=[], system_prompt=system_prompt)
|
| 108 |
+
self.sub_agents = sub_agents
|
| 109 |
+
self.memory = memory
|
| 110 |
+
self.delegation_definitions = self._build_delegation_definitions()
|
| 111 |
+
|
| 112 |
+
def _build_delegation_definitions(self) -> List[Dict]:
|
| 113 |
+
"""
|
| 114 |
+
Dynamically creates OpenAI-compatible function schemas for each sub-agent.
|
| 115 |
+
"""
|
| 116 |
+
definitions = []
|
| 117 |
+
for agent_name, agent in self.sub_agents.items():
|
| 118 |
+
agent_desc = getattr(agent, "description", "A helper agent.")
|
| 119 |
+
schema = {
|
| 120 |
+
"type": "function",
|
| 121 |
+
"function": {
|
| 122 |
+
"name": f"delegate_to_{agent_name}",
|
| 123 |
+
"description": f"Delegate a query to the {agent_name}. Capability: {agent_desc}",
|
| 124 |
+
"parameters": {
|
| 125 |
+
"type": "object",
|
| 126 |
+
"properties": {
|
| 127 |
+
"query": {
|
| 128 |
+
"type": "string",
|
| 129 |
+
"description": "The specific question or instruction for this worker."
|
| 130 |
+
}
|
| 131 |
+
},
|
| 132 |
+
"required": ["query"]
|
| 133 |
+
}
|
| 134 |
+
}
|
| 135 |
+
}
|
| 136 |
+
definitions.append(schema)
|
| 137 |
+
return definitions
|
| 138 |
+
|
| 139 |
+
def process_query(self, user_query: str, provider: LLMProvider) -> AgentResponse:
|
| 140 |
+
"""
|
| 141 |
+
The Manager's Thinking Loop.
|
| 142 |
+
It decides: Do I answer myself? Or do I call a worker?
|
| 143 |
+
"""
|
| 144 |
+
|
| 145 |
+
# 1. Save User Query to Memory
|
| 146 |
+
self.memory.add_message(role="user", content=user_query)
|
| 147 |
+
|
| 148 |
+
# 2. Construct the Context (System Prompt + History)
|
| 149 |
+
team_roster = ", ".join(self.sub_agents.keys())
|
| 150 |
+
enhanced_system_prompt = (
|
| 151 |
+
f"{self.system_prompt}\n"
|
| 152 |
+
f"You manage a team of agents: [{team_roster}].\n"
|
| 153 |
+
f"Delegate tasks to them using the available tools.\n"
|
| 154 |
+
f"Combine their outputs into a comprehensive final answer."
|
| 155 |
+
f"Use the conversation history to answer follow-up questions."
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
# Start with System Prompt
|
| 159 |
+
messages = [{"role": "system", "content": enhanced_system_prompt}]
|
| 160 |
+
|
| 161 |
+
# Add Conversation History
|
| 162 |
+
history = self.memory.get_history()
|
| 163 |
+
messages.extend(history)
|
| 164 |
+
|
| 165 |
+
logger.info(f"👑 [{self.name}] Starting Orchestration Loop...")
|
| 166 |
+
|
| 167 |
+
# 3. Start the Loop (Max 5 turns)
|
| 168 |
+
for turn in range(5):
|
| 169 |
+
logger.info(f"--- Manager Turn {turn + 1} ---")
|
| 170 |
+
|
| 171 |
+
# A. Ask the Provider
|
| 172 |
+
response: LLMResponse = provider.get_response(messages, self.delegation_definitions)
|
| 173 |
+
|
| 174 |
+
# B. Handle "Virtual Tool" Calls (Delegation)
|
| 175 |
+
if response.tool_call:
|
| 176 |
+
tool_name = response.tool_call["name"]
|
| 177 |
+
tool_args = response.tool_call["args"]
|
| 178 |
+
tool_id = response.tool_call.get("id", "call_mgr")
|
| 179 |
+
|
| 180 |
+
if tool_name.startswith("delegate_to_"):
|
| 181 |
+
agent_name = tool_name.replace("delegate_to_", "")
|
| 182 |
+
|
| 183 |
+
if agent_name in self.sub_agents:
|
| 184 |
+
logger.info(f"👑 -> 👷 Delegating to {agent_name}: {tool_args.get('query')}")
|
| 185 |
+
|
| 186 |
+
# Record the "Thought" (Tool Call)
|
| 187 |
+
messages.append({
|
| 188 |
+
"role": "assistant",
|
| 189 |
+
"content": None,
|
| 190 |
+
"tool_calls": [{
|
| 191 |
+
"id": tool_id,
|
| 192 |
+
"type": "function",
|
| 193 |
+
"function": {"name": tool_name, "arguments": json.dumps(tool_args)}
|
| 194 |
+
}]
|
| 195 |
+
})
|
| 196 |
+
|
| 197 |
+
# EXECUTE THE WORKER
|
| 198 |
+
worker_agent = self.sub_agents[agent_name]
|
| 199 |
+
worker_query = tool_args.get("query")
|
| 200 |
+
|
| 201 |
+
try:
|
| 202 |
+
# Worker runs its own loop (stateless for now)
|
| 203 |
+
worker_response = worker_agent.process_query(worker_query, provider)
|
| 204 |
+
worker_content = worker_response.content
|
| 205 |
+
logger.info(f"👷 -> 👑 {agent_name} replied.")
|
| 206 |
+
|
| 207 |
+
except Exception as e:
|
| 208 |
+
worker_content = f"Error from {agent_name}: {str(e)}"
|
| 209 |
+
logger.error(worker_content)
|
| 210 |
+
|
| 211 |
+
# Record the "Observation" (Tool Output)
|
| 212 |
+
messages.append({
|
| 213 |
+
"role": "tool",
|
| 214 |
+
"tool_call_id": tool_id,
|
| 215 |
+
"name": tool_name,
|
| 216 |
+
"content": f"Output from {agent_name}:\n{worker_content}"
|
| 217 |
+
})
|
| 218 |
+
continue
|
| 219 |
+
else:
|
| 220 |
+
logger.warning(f"❌ Manager tried to call unknown agent: {agent_name}")
|
| 221 |
+
messages.append({
|
| 222 |
+
"role": "tool",
|
| 223 |
+
"tool_call_id": tool_id,
|
| 224 |
+
"name": tool_name,
|
| 225 |
+
"content": f"Error: Agent {agent_name} does not exist."
|
| 226 |
+
})
|
| 227 |
+
continue
|
| 228 |
+
|
| 229 |
+
# C. Handle Final Answer (Synthesis)
|
| 230 |
+
if response.content:
|
| 231 |
+
logger.info(f"✅ [{self.name}] Final Synthesis: {response.content}")
|
| 232 |
+
|
| 233 |
+
# 4. Save Assistant Answer to Memory
|
| 234 |
+
self.memory.add_message(role="assistant", content=response.content)
|
| 235 |
+
|
| 236 |
+
return AgentResponse(content=response.content)
|
| 237 |
+
|
| 238 |
+
return AgentResponse(content="Manager timed out while coordinating agents.", metadata={"error": "timeout"})
|
base_tool.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
from typing import List, Dict, Any
|
| 3 |
+
|
| 4 |
+
class BaseTool(ABC):
|
| 5 |
+
"""
|
| 6 |
+
The universal contract for any capability.
|
| 7 |
+
Enforces structure so the Agent Loader can handle any tool automatically.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
@property
|
| 11 |
+
@abstractmethod
|
| 12 |
+
def name(self) -> str:
|
| 13 |
+
"""Unique identifier (e.g., 'get_stock_price')."""
|
| 14 |
+
pass
|
| 15 |
+
|
| 16 |
+
@property
|
| 17 |
+
@abstractmethod
|
| 18 |
+
def description(self) -> str:
|
| 19 |
+
"""Natural language instruction for the LLM."""
|
| 20 |
+
pass
|
| 21 |
+
|
| 22 |
+
@property
|
| 23 |
+
@abstractmethod
|
| 24 |
+
def categories(self) -> List[str]:
|
| 25 |
+
"""Tags for the Registry (e.g., ['finance', 'public_api'])."""
|
| 26 |
+
pass
|
| 27 |
+
|
| 28 |
+
@abstractmethod
|
| 29 |
+
def get_schema(self) -> Dict[str, Any]:
|
| 30 |
+
"""Returns the JSON schema for LLM function calling."""
|
| 31 |
+
pass
|
| 32 |
+
|
| 33 |
+
@abstractmethod
|
| 34 |
+
def run(self, **kwargs) -> Any:
|
| 35 |
+
"""Executes the tool logic."""
|
| 36 |
+
pass
|
llm_provider.py
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
from abc import ABC, abstractmethod
|
| 3 |
+
from typing import List, Dict, Any, Optional
|
| 4 |
+
from pydantic import BaseModel
|
| 5 |
+
from groq import Groq
|
| 6 |
+
from google import genai
|
| 7 |
+
from google.genai import types
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class ProviderError(Exception):
|
| 11 |
+
"""Base class for all provider issues."""
|
| 12 |
+
pass
|
| 13 |
+
|
| 14 |
+
class QuotaExhaustedError(ProviderError):
|
| 15 |
+
"""Raised when we run out of credits/limit."""
|
| 16 |
+
pass
|
| 17 |
+
|
| 18 |
+
class ProviderDownError(ProviderError):
|
| 19 |
+
"""Raised when the provider is temporarily broken (500, 429)."""
|
| 20 |
+
pass
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class LLMResponse(BaseModel):
|
| 24 |
+
content: Optional[str] = None
|
| 25 |
+
tool_call: Optional[Dict[str, Any]] = None
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class LLMProvider(ABC):
|
| 29 |
+
@abstractmethod
|
| 30 |
+
def get_response(self, messages: List[Dict[str, str]], tools: List[Dict]) -> LLMResponse:
|
| 31 |
+
"""
|
| 32 |
+
Args:
|
| 33 |
+
messages: Full conversation history [{"role": "user", "content": "..."}, ...]
|
| 34 |
+
tools: JSON Schema definitions for tools.
|
| 35 |
+
"""
|
| 36 |
+
pass
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class GroqProvider(LLMProvider):
|
| 40 |
+
def __init__(self, api_key: str, model_name: str = 'llama-3.1-8b-instant'):
|
| 41 |
+
self.client = Groq(api_key=api_key)
|
| 42 |
+
self.model_name = model_name
|
| 43 |
+
|
| 44 |
+
def get_response(self, messages: List[Dict[str, str]], tools: List[Dict]) -> LLMResponse:
|
| 45 |
+
try:
|
| 46 |
+
# Groq/OpenAI native format
|
| 47 |
+
response = self.client.chat.completions.create(
|
| 48 |
+
model=self.model_name,
|
| 49 |
+
messages=messages,
|
| 50 |
+
tools=tools,
|
| 51 |
+
tool_choice="auto",
|
| 52 |
+
temperature=0.1
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
candidate = response.choices[0]
|
| 56 |
+
|
| 57 |
+
# Check for tool calls
|
| 58 |
+
if candidate.message.tool_calls:
|
| 59 |
+
# We take the first tool call
|
| 60 |
+
tool_call_data = candidate.message.tool_calls[0]
|
| 61 |
+
return LLMResponse(
|
| 62 |
+
tool_call={
|
| 63 |
+
"name": tool_call_data.function.name,
|
| 64 |
+
"args": json.loads(tool_call_data.function.arguments),
|
| 65 |
+
"id": tool_call_data.id # Store ID for history tracking
|
| 66 |
+
}
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
# Return text content
|
| 70 |
+
return LLMResponse(content=candidate.message.content)
|
| 71 |
+
|
| 72 |
+
except Exception as e:
|
| 73 |
+
error_msg = str(e).lower()
|
| 74 |
+
if "resource_exhausted" in error_msg or "quota" in error_msg:
|
| 75 |
+
raise QuotaExhaustedError("Groq Quota Exhausted")
|
| 76 |
+
else:
|
| 77 |
+
raise ProviderDownError(f"Groq Error: {e}")
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class GeminiProvider(LLMProvider):
|
| 81 |
+
def __init__(self, api_key: str, model_name: str = 'gemini-2.0-flash'):
|
| 82 |
+
self.client = genai.Client(api_key=api_key)
|
| 83 |
+
self.model_name = model_name
|
| 84 |
+
|
| 85 |
+
def _map_tools(self, tools: List[Dict]) -> List[types.Tool]:
|
| 86 |
+
"""
|
| 87 |
+
Converts OpenAI/Groq-style tool definitions into Gemini types.Tool objects.
|
| 88 |
+
"""
|
| 89 |
+
gemini_tools = []
|
| 90 |
+
for t in tools:
|
| 91 |
+
# Check if it matches the OpenAI schema {"type": "function", "function": {...}}
|
| 92 |
+
if t.get("type") == "function":
|
| 93 |
+
func_def = t["function"]
|
| 94 |
+
|
| 95 |
+
# Create the Gemini-specific FunctionDeclaration
|
| 96 |
+
fn_decl = types.FunctionDeclaration(
|
| 97 |
+
name=func_def["name"],
|
| 98 |
+
description=func_def.get("description"),
|
| 99 |
+
parameters=func_def.get("parameters")
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
# Wrap it in a Tool object
|
| 103 |
+
gemini_tools.append(types.Tool(function_declarations=[fn_decl]))
|
| 104 |
+
return gemini_tools
|
| 105 |
+
|
| 106 |
+
def _default_history_format(self, messages: List[Dict]) -> str:
|
| 107 |
+
formatted_prompt = ""
|
| 108 |
+
for msg in messages:
|
| 109 |
+
role = msg["role"]
|
| 110 |
+
content = msg.get("content", "") or ""
|
| 111 |
+
if role == "system":
|
| 112 |
+
formatted_prompt += f"System Instruction: {content}\n\n"
|
| 113 |
+
elif role == "user":
|
| 114 |
+
formatted_prompt += f"User: {content}\n"
|
| 115 |
+
elif role == "assistant":
|
| 116 |
+
if "tool_calls" in msg:
|
| 117 |
+
tc = msg["tool_calls"][0]
|
| 118 |
+
formatted_prompt += f"Assistant (Thought): I will call tool '{tc['function']['name']}' with args {tc['function']['arguments']}.\n"
|
| 119 |
+
else:
|
| 120 |
+
formatted_prompt += f"Assistant: {content}\n"
|
| 121 |
+
elif role == "tool":
|
| 122 |
+
formatted_prompt += f"Tool Output ({msg.get('name')}): {content}\n"
|
| 123 |
+
formatted_prompt += "\nBased on the history above, provide the next response or tool call."
|
| 124 |
+
return formatted_prompt
|
| 125 |
+
|
| 126 |
+
def get_response(self, messages: List[Dict[str, str]], tools: List[Dict]) -> LLMResponse:
|
| 127 |
+
try:
|
| 128 |
+
# 1. Translate History
|
| 129 |
+
full_prompt = self._default_history_format(messages)
|
| 130 |
+
|
| 131 |
+
gemini_messages = [
|
| 132 |
+
types.Content(role="user", parts=[types.Part(text=full_prompt)])
|
| 133 |
+
]
|
| 134 |
+
|
| 135 |
+
# 2. Translate Tools
|
| 136 |
+
mapped_tools = self._map_tools(tools)
|
| 137 |
+
|
| 138 |
+
config = types.GenerateContentConfig(
|
| 139 |
+
tools=mapped_tools,
|
| 140 |
+
temperature=0.0
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
response = self.client.models.generate_content(
|
| 144 |
+
model=self.model_name,
|
| 145 |
+
contents=gemini_messages,
|
| 146 |
+
config=config,
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
candidate = response.candidates[0]
|
| 150 |
+
function_call_part = None
|
| 151 |
+
for part in candidate.content.parts:
|
| 152 |
+
if part.function_call:
|
| 153 |
+
function_call_part = part
|
| 154 |
+
break
|
| 155 |
+
|
| 156 |
+
if function_call_part:
|
| 157 |
+
return LLMResponse(
|
| 158 |
+
tool_call={
|
| 159 |
+
"name": function_call_part.function_call.name,
|
| 160 |
+
"args": function_call_part.function_call.args
|
| 161 |
+
}
|
| 162 |
+
)
|
| 163 |
+
|
| 164 |
+
return LLMResponse(content=candidate.content.parts[0].text)
|
| 165 |
+
|
| 166 |
+
except Exception as e:
|
| 167 |
+
error_msg = str(e).lower()
|
| 168 |
+
if "resource_exhausted" in error_msg or "quota" in error_msg:
|
| 169 |
+
raise QuotaExhaustedError("Gemini Quota Exhausted")
|
| 170 |
+
else:
|
| 171 |
+
raise ProviderDownError(f"Gemini Error: {e}")
|
memory.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tiktoken
|
| 2 |
+
from abc import ABC, abstractmethod
|
| 3 |
+
from typing import List, Dict
|
| 4 |
+
import logging
|
| 5 |
+
|
| 6 |
+
logger = logging.getLogger("FinancialAgent")
|
| 7 |
+
|
| 8 |
+
class BaseMemory(ABC):
|
| 9 |
+
"""
|
| 10 |
+
Abstract Base Class for memory management.
|
| 11 |
+
"""
|
| 12 |
+
@abstractmethod
|
| 13 |
+
def add_message(self, role: str, content: str):
|
| 14 |
+
pass
|
| 15 |
+
|
| 16 |
+
@abstractmethod
|
| 17 |
+
def get_history(self) -> List[Dict[str, str]]:
|
| 18 |
+
pass
|
| 19 |
+
|
| 20 |
+
@abstractmethod
|
| 21 |
+
def clear(self):
|
| 22 |
+
pass
|
| 23 |
+
|
| 24 |
+
class TokenBufferMemory(BaseMemory):
|
| 25 |
+
"""
|
| 26 |
+
Memory that keeps conversation within a strict token limit.
|
| 27 |
+
Uses a FIFO (First-In-First-Out) eviction strategy when full.
|
| 28 |
+
"""
|
| 29 |
+
def __init__(self, max_tokens: int = 4096, encoding_name: str = "cl100k_base"):
|
| 30 |
+
self.max_tokens = max_tokens
|
| 31 |
+
self.messages = []
|
| 32 |
+
# cl100k_base is the encoding for GPT-4 and acts as a good standard proxy
|
| 33 |
+
self.tokenizer = tiktoken.get_encoding(encoding_name)
|
| 34 |
+
|
| 35 |
+
def _count_tokens(self, text: str) -> int:
|
| 36 |
+
"""Helper to count tokens in a string."""
|
| 37 |
+
try:
|
| 38 |
+
return len(self.tokenizer.encode(text))
|
| 39 |
+
except Exception:
|
| 40 |
+
# Fallback for empty strings or weird encoding errors
|
| 41 |
+
return 0
|
| 42 |
+
|
| 43 |
+
def _evict_if_needed(self):
|
| 44 |
+
"""
|
| 45 |
+
Removes oldest messages until we are under the token limit.
|
| 46 |
+
Safety: Never deletes the most recent message (index -1),
|
| 47 |
+
so we always have at least the latest context.
|
| 48 |
+
"""
|
| 49 |
+
while len(self.messages) > 1:
|
| 50 |
+
current_buffer_tokens = sum(self._count_tokens(m["content"]) for m in self.messages)
|
| 51 |
+
|
| 52 |
+
if current_buffer_tokens <= self.max_tokens:
|
| 53 |
+
break
|
| 54 |
+
|
| 55 |
+
# Remove the oldest message
|
| 56 |
+
removed = self.messages.pop(0)
|
| 57 |
+
logger.info(f"🧹 Memory Full. Evicted message: {removed['role']} ({len(removed['content'])} chars)")
|
| 58 |
+
|
| 59 |
+
def add_message(self, role: str, content: str):
|
| 60 |
+
"""Adds a message and triggers eviction check."""
|
| 61 |
+
self.messages.append({"role": role, "content": content})
|
| 62 |
+
self._evict_if_needed()
|
| 63 |
+
|
| 64 |
+
def get_history(self) -> List[Dict[str, str]]:
|
| 65 |
+
return self.messages
|
| 66 |
+
|
| 67 |
+
def clear(self):
|
| 68 |
+
self.messages = []
|
registry.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, List
|
| 2 |
+
from base_tool import BaseTool
|
| 3 |
+
|
| 4 |
+
class ToolRegistry:
|
| 5 |
+
def __init__(self):
|
| 6 |
+
# Index 1: Look up by Name (for execution)
|
| 7 |
+
self._tools_by_name: Dict[str, BaseTool] = {}
|
| 8 |
+
|
| 9 |
+
# Index 2: Look up by Category (for subscription)
|
| 10 |
+
self._tools_by_category: Dict[str, List[BaseTool]] = {}
|
| 11 |
+
|
| 12 |
+
def register(self, tool: BaseTool):
|
| 13 |
+
if tool.name in self._tools_by_name:
|
| 14 |
+
raise ValueError(f"Tool '{tool.name}' is already registered.")
|
| 15 |
+
|
| 16 |
+
# 1. Add to Main Index
|
| 17 |
+
self._tools_by_name[tool.name] = tool
|
| 18 |
+
|
| 19 |
+
# 2. Add to Category Index
|
| 20 |
+
for category in tool.categories:
|
| 21 |
+
if category not in self._tools_by_category:
|
| 22 |
+
self._tools_by_category[category] = []
|
| 23 |
+
self._tools_by_category[category].append(tool)
|
| 24 |
+
|
| 25 |
+
def get_tool(self, name: str) -> BaseTool:
|
| 26 |
+
return self._tools_by_name.get(name)
|
| 27 |
+
|
| 28 |
+
def get_tools_by_category(self, category: str) -> List[BaseTool]:
|
| 29 |
+
return self._tools_by_category.get(category, [])
|
requirements.txt
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
altair==6.0.0
|
| 2 |
+
annotated-doc==0.0.4
|
| 3 |
+
annotated-types==0.7.0
|
| 4 |
+
anyio==4.12.0
|
| 5 |
+
attrs==25.4.0
|
| 6 |
+
beautifulsoup4==4.14.3
|
| 7 |
+
blinker==1.9.0
|
| 8 |
+
cachetools==6.2.2
|
| 9 |
+
certifi==2025.11.12
|
| 10 |
+
cffi==2.0.0
|
| 11 |
+
charset-normalizer==3.4.4
|
| 12 |
+
click==8.3.1
|
| 13 |
+
curl_cffi==0.13.0
|
| 14 |
+
distro==1.9.0
|
| 15 |
+
docopt==0.6.2
|
| 16 |
+
dotenv==0.9.9
|
| 17 |
+
fastapi==0.123.9
|
| 18 |
+
frozendict==2.4.7
|
| 19 |
+
gitdb==4.0.12
|
| 20 |
+
GitPython==3.1.45
|
| 21 |
+
google-ai-generativelanguage==0.6.15
|
| 22 |
+
google-api-core==2.28.1
|
| 23 |
+
google-api-python-client==2.187.0
|
| 24 |
+
google-auth==2.43.0
|
| 25 |
+
google-auth-httplib2==0.2.1
|
| 26 |
+
google-genai==1.53.0
|
| 27 |
+
googleapis-common-protos==1.72.0
|
| 28 |
+
groq==0.37.1
|
| 29 |
+
grpcio==1.76.0
|
| 30 |
+
grpcio-status==1.71.2
|
| 31 |
+
h11==0.16.0
|
| 32 |
+
httpcore==1.0.9
|
| 33 |
+
httplib2==0.31.0
|
| 34 |
+
httpx==0.28.1
|
| 35 |
+
idna==3.11
|
| 36 |
+
iniconfig==2.3.0
|
| 37 |
+
Jinja2==3.1.6
|
| 38 |
+
jsonschema==4.25.1
|
| 39 |
+
jsonschema-specifications==2025.9.1
|
| 40 |
+
MarkupSafe==3.0.3
|
| 41 |
+
multitasking==0.0.12
|
| 42 |
+
narwhals==2.13.0
|
| 43 |
+
numpy==2.3.5
|
| 44 |
+
packaging==25.0
|
| 45 |
+
pandas==2.3.3
|
| 46 |
+
peewee==3.18.3
|
| 47 |
+
pillow==12.0.0
|
| 48 |
+
pipreqs==0.4.13
|
| 49 |
+
platformdirs==4.5.1
|
| 50 |
+
pluggy==1.6.0
|
| 51 |
+
proto-plus==1.26.1
|
| 52 |
+
protobuf==5.29.5
|
| 53 |
+
pyarrow==22.0.0
|
| 54 |
+
pyasn1==0.6.1
|
| 55 |
+
pyasn1_modules==0.4.2
|
| 56 |
+
pycparser==2.23
|
| 57 |
+
pydantic==2.12.5
|
| 58 |
+
pydantic_core==2.41.5
|
| 59 |
+
pydeck==0.9.1
|
| 60 |
+
Pygments==2.19.2
|
| 61 |
+
pyparsing==3.2.5
|
| 62 |
+
pytest==9.0.2
|
| 63 |
+
python-dateutil==2.9.0.post0
|
| 64 |
+
python-dotenv==1.2.1
|
| 65 |
+
pytz==2025.2
|
| 66 |
+
referencing==0.37.0
|
| 67 |
+
regex==2025.11.3
|
| 68 |
+
requests==2.32.5
|
| 69 |
+
rpds-py==0.30.0
|
| 70 |
+
rsa==4.9.1
|
| 71 |
+
ruamel.yaml==0.18.16
|
| 72 |
+
ruamel.yaml.clib==0.2.15
|
| 73 |
+
singleton==0.1.0
|
| 74 |
+
six==1.17.0
|
| 75 |
+
smmap==5.0.2
|
| 76 |
+
sniffio==1.3.1
|
| 77 |
+
soupsieve==2.8
|
| 78 |
+
starlette==0.50.0
|
| 79 |
+
streamlit==1.52.0
|
| 80 |
+
tenacity==9.1.2
|
| 81 |
+
tiktoken==0.12.0
|
| 82 |
+
toml==0.10.2
|
| 83 |
+
tornado==6.5.2
|
| 84 |
+
tqdm==4.67.1
|
| 85 |
+
typing-inspection==0.4.2
|
| 86 |
+
typing_extensions==4.15.0
|
| 87 |
+
tzdata==2025.2
|
| 88 |
+
uritemplate==4.2.0
|
| 89 |
+
urllib3==2.5.0
|
| 90 |
+
uvicorn==0.38.0
|
| 91 |
+
websockets==15.0.1
|
| 92 |
+
yarg==0.1.10
|
| 93 |
+
yfinance==0.2.66
|
tools.py
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import yfinance as yf
|
| 2 |
+
from typing import Dict, Any, List
|
| 3 |
+
from base_tool import BaseTool
|
| 4 |
+
from registry import ToolRegistry
|
| 5 |
+
|
| 6 |
+
import logging
|
| 7 |
+
logger = logging.getLogger("FinancialAgent")
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
# backend/tools.py
|
| 11 |
+
|
| 12 |
+
# backend/tools.py
|
| 13 |
+
|
| 14 |
+
from typing import Dict, Any, List
|
| 15 |
+
# import BaseTool if it's in a separate file, or assume it's available
|
| 16 |
+
|
| 17 |
+
class CalculatorTool(BaseTool):
|
| 18 |
+
name = "calculator"
|
| 19 |
+
description = "Perform basic arithmetic operations. Use this for calculating differences, ratios, or percentages."
|
| 20 |
+
categories = ["utils"]
|
| 21 |
+
|
| 22 |
+
def get_schema(self) -> Dict[str, Any]:
|
| 23 |
+
return {
|
| 24 |
+
"type": "function",
|
| 25 |
+
"function": {
|
| 26 |
+
"name": self.name,
|
| 27 |
+
"description": self.description,
|
| 28 |
+
"parameters": {
|
| 29 |
+
"type": "object",
|
| 30 |
+
"properties": {
|
| 31 |
+
"operation": {
|
| 32 |
+
"type": "string",
|
| 33 |
+
"enum": ["add", "subtract", "multiply", "divide"],
|
| 34 |
+
"description": "The math operation to perform."
|
| 35 |
+
},
|
| 36 |
+
"x": {
|
| 37 |
+
"type": "number",
|
| 38 |
+
"description": "The first number."
|
| 39 |
+
},
|
| 40 |
+
"y": {
|
| 41 |
+
"type": "number",
|
| 42 |
+
"description": "The second number."
|
| 43 |
+
}
|
| 44 |
+
},
|
| 45 |
+
"required": ["operation", "x", "y"]
|
| 46 |
+
}
|
| 47 |
+
}
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
def run(self, operation: str, x: float, y: float) -> Any:
|
| 51 |
+
try:
|
| 52 |
+
# Ensure numbers are floats (in case string is passed)
|
| 53 |
+
x, y = float(x), float(y)
|
| 54 |
+
|
| 55 |
+
if operation == "add":
|
| 56 |
+
return {"result": x + y}
|
| 57 |
+
elif operation == "subtract":
|
| 58 |
+
return {"result": x - y}
|
| 59 |
+
elif operation == "multiply":
|
| 60 |
+
return {"result": x * y}
|
| 61 |
+
elif operation == "divide":
|
| 62 |
+
if y == 0:
|
| 63 |
+
return {"error": "Error: Division by zero"}
|
| 64 |
+
return {"result": x / y}
|
| 65 |
+
else:
|
| 66 |
+
return {"error": f"Unknown operation: {operation}"}
|
| 67 |
+
except Exception as e:
|
| 68 |
+
return {"error": f"Math execution failed: {str(e)}"}
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class StockPriceTool(BaseTool):
|
| 72 |
+
name = "get_stock_price"
|
| 73 |
+
description = "Get the current price of a stock using its Ticker Symbol."
|
| 74 |
+
categories = ["finance"]
|
| 75 |
+
|
| 76 |
+
def get_schema(self) -> Dict[str, Any]:
|
| 77 |
+
return {
|
| 78 |
+
"type": "function",
|
| 79 |
+
"function": {
|
| 80 |
+
"name": self.name,
|
| 81 |
+
"description": self.description,
|
| 82 |
+
"parameters": {
|
| 83 |
+
"type": "object",
|
| 84 |
+
"properties": {
|
| 85 |
+
"ticker_symbol": {"type": "string", "description": "The stock ticker (e.g., AAPL)"}
|
| 86 |
+
},
|
| 87 |
+
"required": ["ticker_symbol"]
|
| 88 |
+
}
|
| 89 |
+
}
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
def run(self, ticker_symbol: str) -> Dict:
|
| 93 |
+
try:
|
| 94 |
+
ticker = yf.Ticker(ticker_symbol)
|
| 95 |
+
info = ticker.fast_info
|
| 96 |
+
return {"ticker": ticker_symbol, "price": info.last_price, "currency": info.currency}
|
| 97 |
+
except Exception as e:
|
| 98 |
+
return {"error": str(e)}
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
class CompanyNewsTool(BaseTool):
|
| 102 |
+
name = "get_company_news"
|
| 103 |
+
description = "Get the latest news summaries about a company."
|
| 104 |
+
categories = ["news"]
|
| 105 |
+
|
| 106 |
+
def get_schema(self) -> Dict[str, Any]:
|
| 107 |
+
return {
|
| 108 |
+
"type": "function",
|
| 109 |
+
"function": {
|
| 110 |
+
"name": self.name,
|
| 111 |
+
"description": self.description,
|
| 112 |
+
"parameters": {
|
| 113 |
+
"type": "object",
|
| 114 |
+
"properties": {
|
| 115 |
+
"ticker_symbol": {"type": "string", "description": "The stock ticker"},
|
| 116 |
+
"num_stories": {"type": "integer", "description": "Number of stories"}
|
| 117 |
+
},
|
| 118 |
+
"required": ["ticker_symbol", "num_stories"]
|
| 119 |
+
}
|
| 120 |
+
}
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
def run(self, ticker_symbol: str, num_stories: int = 5) -> Dict:
|
| 124 |
+
try:
|
| 125 |
+
ticker = yf.Ticker(ticker_symbol)
|
| 126 |
+
news = ticker.news[:num_stories] if ticker.news else []
|
| 127 |
+
if len(news) == 0:
|
| 128 |
+
return {"error": f'HTTP Error 404: ${ticker_symbol}: possibly delisted; Quote not found for symbol'}
|
| 129 |
+
return {"ticker": ticker_symbol, "news": [n['content']['title'] for n in news], "storiesFetched": len(news)}
|
| 130 |
+
except Exception as e:
|
| 131 |
+
return {"error": str(e)}
|
| 132 |
+
|
| 133 |
+
def initialize_registry() -> ToolRegistry:
|
| 134 |
+
registry = ToolRegistry()
|
| 135 |
+
tools_list = [
|
| 136 |
+
StockPriceTool(),
|
| 137 |
+
CompanyNewsTool(),
|
| 138 |
+
CalculatorTool()
|
| 139 |
+
]
|
| 140 |
+
for tool in tools_list:
|
| 141 |
+
registry.register(tool)
|
| 142 |
+
return registry
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
if __name__ == "__main__":
|
| 146 |
+
|
| 147 |
+
pass
|